hv0905/NekoImageGallery

View on GitHub
app/Controllers/search.py

Summary

Maintainability
A
0 mins
Test Coverage
from io import BytesIO
from typing import Annotated, List
from uuid import uuid4, UUID

from PIL import Image
from fastapi import APIRouter, HTTPException
from fastapi.params import File, Query, Path, Depends
from loguru import logger

from app.Models.api_models.search_api_model import AdvancedSearchModel, CombinedSearchModel, SearchBasisEnum
from app.Models.api_response.search_api_response import SearchApiResponse
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Models.search_result import SearchResult
from app.Services.authentication import force_access_token_verify
from app.Services.provider import ServiceProvider
from app.config import config
from app.util.calculate_vectors_cosine import calculate_vectors_cosine

# Router that the search endpoints in this module are registered on. When
# `config.access_protected` is truthy, `force_access_token_verify` is attached as a
# router-level dependency; otherwise no router-level dependency is set. All routes
# registered on it carry the "Search" tag.
search_router = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None),
                          tags=["Search"])

# Module-level handle through which the code in this module reaches the transformers,
# database and storage services. It is None at import time.
services: ServiceProvider | None = None  # The service provider will be injected in the webapp initialize


class SearchBasisParams:
    """Dependency class that reads and validates the `basis` query parameter.

    The selected basis is stored on the instance as `basis`. It defaults to
    `SearchBasisEnum.vision` when the parameter is omitted.

    Raises:
        HTTPException: status 400 when the OCR basis is requested while
            `config.ocr_search.enable` is falsy.
    """

    def __init__(self,
                 basis: Annotated[SearchBasisEnum, Query(
                     description="The basis used to search the image.")] = SearchBasisEnum.vision):
        ocr_requested = basis == SearchBasisEnum.ocr
        # The OCR switch in the config is only consulted when OCR was actually asked for.
        if ocr_requested and not config.ocr_search.enable:
            raise HTTPException(400, "OCR search is not enabled.")
        self.basis = basis


async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse:
    """Replace storage-backed URLs in a search response with presigned URLs.

    When `config.storage.method.enabled` is falsy the response is returned untouched.
    Otherwise every result whose image is flagged `local` gets its `url` replaced, and
    every result that has a thumbnail URL and is flagged `local` or `local_thumbnail`
    gets its `thumbnail_url` replaced. The response object is modified in place.

    Args:
        resp: The response whose `result` items should be rewritten.

    Returns:
        The same `resp` object that was passed in.
    """
    if not config.storage.method.enabled:
        return resp
    for entry in resp.result:
        image = entry.img
        if image.local:
            # Prefer the recorded format; fall back to the suffix of the current URL.
            extension = image.format or image.url.split('.')[-1]
            image.url = await services.storage_service.active_storage.presign_url(f"{image.id}.{extension}")
        thumbnail_in_storage = image.thumbnail_url is not None and (image.local or image.local_thumbnail)
        if thumbnail_in_storage:
            image.thumbnail_url = await services.storage_service.active_storage.presign_url(
                f"thumbnails/{image.id}.webp")
    return resp


@search_router.get("/text/{prompt}", description="Search images by text prompt")
async def textSearch(
        prompt: Annotated[
            str, Path(max_length=100, description="The image prompt text you want to search.")],
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
        exact: Annotated[bool, Query(
            description="If using OCR search, this option will require the ocr text contains **exactly** the "
                        "criteria you have given. This won't take any effect in vision search.")] = False
) -> SearchApiResponse:
    """Search images with a text prompt.

    The prompt is encoded with `get_text_vector` for the vision basis and with
    `get_bert_vector` for any other basis. When `exact` is set together with the OCR
    basis, the prompt is also written to `filter_param.ocr_text`.

    Returns:
        The post-processed response holding the matches and a fresh query id.
    """
    logger.info("Text search request received, prompt: {}", prompt)
    if basis.basis == SearchBasisEnum.vision:
        query_vector = services.transformers_service.get_text_vector(prompt)
    else:
        query_vector = services.transformers_service.get_bert_vector(prompt)

    if exact and basis.basis == SearchBasisEnum.ocr:
        filter_param.ocr_text = prompt

    vector_name = services.db_context.vector_name_for_basis(basis.basis)
    hits = await services.db_context.querySearch(query_vector,
                                                 query_vector_name=vector_name,
                                                 filter_param=filter_param,
                                                 top_k=paging.count,
                                                 skip=paging.skip)
    response = SearchApiResponse(result=hits, message=f"Successfully get {len(hits)} results.", query_id=uuid4())
    return await result_postprocessing(response)


@search_router.post("/image", description="Search images by image")
async def imageSearch(
        image: Annotated[bytes, File(max_length=10 * 1024 * 1024, media_type="image/*",
                                     description="The image you want to search.")],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
    """Search images using an uploaded image as the query.

    The uploaded bytes are opened with PIL, encoded with `get_image_vector`, and the
    resulting vector is passed to `querySearch` together with the filter and paging
    parameters.

    Returns:
        The post-processed response holding the matches and a fresh query id.

    Raises:
        HTTPException: status 400 when PIL cannot identify the upload as an image.
    """
    # Imported here so the handler carries its own dependency; PIL is already used by this module.
    from PIL import UnidentifiedImageError

    logger.info("Image search request received")
    try:
        img = Image.open(BytesIO(image))
    except UnidentifiedImageError as ex:
        # The payload comes from the client, so an undecodable upload is reported as a
        # bad request instead of propagating as an unhandled exception.
        logger.warning("Image search rejected, the uploaded file is not a valid image: {}", ex)
        raise HTTPException(400, "The uploaded file is not a valid image.") from ex
    image_vector = services.transformers_service.get_image_vector(img)
    results = await services.db_context.querySearch(image_vector,
                                                    top_k=paging.count,
                                                    skip=paging.skip,
                                                    filter_param=filter_param)
    return await result_postprocessing(
        SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()))


@search_router.get("/similar/{image_id}",
                   description="Search images similar to the image with given id. "
                               "Won't include the given image itself in the result.")
async def similarWith(
        image_id: Annotated[UUID, Path(description="The id of the image you want to search.")],
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
    """Search for images similar to an already indexed image.

    The id is passed to `querySimilar` as a string, along with the vector name chosen
    for the requested basis and the filter and paging parameters.

    Returns:
        The post-processed response holding the matches and a fresh query id.
    """
    logger.info("Similar search request received, id: {}", image_id)
    vector_name = services.db_context.vector_name_for_basis(basis.basis)
    hits = await services.db_context.querySimilar(search_id=str(image_id),
                                                  top_k=paging.count,
                                                  skip=paging.skip,
                                                  filter_param=filter_param,
                                                  query_vector_name=vector_name)
    response = SearchApiResponse(result=hits, message=f"Successfully get {len(hits)} results.", query_id=uuid4())
    return await result_postprocessing(response)


@search_router.post("/advanced", description="Search with multiple criteria")
async def advancedSearch(
        model: AdvancedSearchModel,
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
    """Search with the positive and negative criteria carried by `model`.

    The query itself is delegated to `process_advanced_and_combined_search_query`.

    Returns:
        The post-processed response holding the matches and a fresh query id.
    """
    logger.info("Advanced search request received: {}", model)
    hits = await process_advanced_and_combined_search_query(model, basis, filter_param, paging)
    response = SearchApiResponse(result=hits, message=f"Successfully get {len(hits)} results.", query_id=uuid4())
    return await result_postprocessing(response)


@search_router.post("/combined", description="Search with combined criteria")
async def combinedSearch(
        model: CombinedSearchModel,
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
    """Search with the criteria in `model`, then re-rank the hits by its extra prompt.

    The candidates come from `process_advanced_and_combined_search_query` in combined
    mode, are re-scored and sorted by `calculate_and_sort_by_combined_scores`, and are
    finally cut down to at most `paging.count` items.

    Returns:
        The post-processed response holding the re-ranked matches and a fresh query id.

    Raises:
        HTTPException: status 400 when `config.ocr_search.enable` is falsy.
    """
    if not config.ocr_search.enable:
        raise HTTPException(
            400, "You used combined search, but it needs OCR search which is not enabled.")
    logger.info("Combined search request received: {}", model)
    candidates = await process_advanced_and_combined_search_query(model, basis, filter_param, paging,
                                                                  is_combined_search=True)
    calculate_and_sort_by_combined_scores(model, basis, candidates)
    # Slicing already leaves a list that is no longer than the requested count as it is.
    top_hits = candidates[:paging.count]
    response = SearchApiResponse(result=top_hits, message=f"Successfully get {len(top_hits)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)


@search_router.get("/random", description="Get random images")
async def randomPick(
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
        seed: Annotated[int | None, Query(
            description="The seed for random pick. This is helpful for generating a reproducible random pick.")] = None,
) -> SearchApiResponse:
    """Return images found by searching with a random query vector.

    The vector comes from `get_random_vector(seed)` and is passed to `querySearch`
    together with the filter and paging parameters.

    Returns:
        The post-processed response holding the picked images and a fresh query id.
    """
    logger.info("Random pick request received")
    query_vector = services.transformers_service.get_random_vector(seed)
    picks = await services.db_context.querySearch(query_vector,
                                                  top_k=paging.count,
                                                  skip=paging.skip,
                                                  filter_param=filter_param)
    response = SearchApiResponse(result=picks, message=f"Successfully get {len(picks)} results.", query_id=uuid4())
    return await result_postprocessing(response)


# @search_router.get("/recall/{query_id}", description="Recall the query with given queryId")
# async def recallQuery(query_id: str):
#     raise NotImplementedError()

async def process_advanced_and_combined_search_query(model: AdvancedSearchModel,
                                                     basis: SearchBasisParams,
                                                     filter_param: FilterParams,
                                                     paging: SearchPagingParams,
                                                     is_combined_search=False) -> List[SearchResult]:
    """Run the vector query shared by the advanced and combined search endpoints.

    Every entry of `model.criteria` and `model.negative_criteria` is encoded with
    `get_bert_vector` for the OCR basis or `get_text_vector` for the vision basis, and
    the two vector lists are handed to `querySimilar`.

    Args:
        model: Carries the positive criteria, the negative criteria and the mode.
        basis: Selects the encoder and the vector name to query.
        filter_param: Passed through to the database query.
        paging: Supplies `count` and `skip`.
        is_combined_search: When true, vectors are requested with the results and the
            number of requested results is enlarged (see below).

    Returns:
        The results returned by `querySimilar`.
    """
    if basis.basis == SearchBasisEnum.ocr:
        encode = services.transformers_service.get_bert_vector
    elif basis.basis == SearchBasisEnum.vision:
        encode = services.transformers_service.get_text_vector
    else:  # pragma: no cover
        raise NotImplementedError()
    positive_vectors = [encode(text) for text in model.criteria]
    negative_vectors = [encode(text) for text in model.negative_criteria]

    # A combined search is re-ranked afterwards, so it asks for a larger candidate pool:
    # three times the page size, but never fewer than 30 or more than 100.
    if is_combined_search:
        candidate_count = min(max(30, paging.count * 3), 100)
    else:
        candidate_count = paging.count

    return await services.db_context.querySimilar(
        query_vector_name=services.db_context.vector_name_for_basis(basis.basis),
        positive_vectors=positive_vectors,
        negative_vectors=negative_vectors,
        mode=model.mode,
        filter_param=filter_param,
        with_vectors=is_combined_search,
        top_k=candidate_count,
        skip=paging.skip)


def calculate_and_sort_by_combined_scores(model: CombinedSearchModel,
                                          basis: SearchBasisParams,
                                          result: List[SearchResult]) -> None:
    """Re-score `result` against the extra prompt and sort it in place, best first.

    For the OCR basis the extra prompt is encoded with `get_text_vector` and compared
    with each item's `img.image_vector`. For the vision basis it is encoded with
    `get_bert_vector` and compared with each item's `img.text_contain_vector`. Items
    whose stored vector is None keep their original score.

    Args:
        model: Supplies `extra_prompt`.
        basis: Selects which encoder and which stored vector are used.
        result: The search results; their `score` is updated and the list is reordered.
    """
    # The extra prompt is encoded once, and the stored vector to compare against is
    # fixed for the whole list because both depend only on the basis.
    if basis.basis == SearchBasisEnum.ocr:
        extra_prompt_vector = services.transformers_service.get_text_vector(model.extra_prompt)
        stored_vector_attr = "image_vector"
    elif basis.basis == SearchBasisEnum.vision:
        extra_prompt_vector = services.transformers_service.get_bert_vector(model.extra_prompt)
        stored_vector_attr = "text_contain_vector"
    else:  # pragma: no cover
        raise NotImplementedError()

    for entry in result:
        stored_vector = getattr(entry.img, stored_vector_attr)
        if stored_vector is None:
            continue
        similarity = calculate_vectors_cosine(stored_vector, extra_prompt_vector)
        # Combined score: the original score scaled by (1 + cosine similarity).
        entry.score = (1 + similarity) * entry.score

    # Highest combined score first.
    result.sort(key=lambda entry: entry.score, reverse=True)