hv0905/NekoImageGallery

View on GitHub
scripts/db_migrations.py

Summary

Maintainability
A
0 mins
Test Coverage
from loguru import logger

from app.Services.provider import ServiceProvider

# Newest database schema version this script knows how to migrate to.
CURRENT_VERSION: int = 2

# Shared service container used by the migration steps.
# It stays None until `migrate()` creates and loads it.
services: ServiceProvider | None = None


async def migrate_v1_v2():
    """Upgrade every stored point from the v1 schema to the v2 schema.

    Points are scrolled from ``services.db_context`` in pages of 100. For each point:

    * if its ``url`` starts with ``'/'``, ``point.local`` is set to ``True``
      (v1 inferred locality from the URL alone; v2 records it explicitly);
    * its payload is rewritten with ``updatePayload``;
    * if ``ocr_text`` is not None, ``text_contain_vector`` is computed from
      ``ocr_text_lower`` with ``transformers_service.get_bert_vector``.

    Vectors are written back once per page, and only for the points that
    have ``text_contain_vector`` set. A page with no such point skips the
    vector-update call entirely.

    Raises:
        RuntimeError: if the module-level ``services`` has not been
            initialized (``migrate()`` does that before calling this).
    """
    if services is None:
        raise RuntimeError("services is not initialized; run migrate() instead of calling migrate_v1_v2() directly")

    logger.info("Migrating from v1 to v2...")
    db = services.db_context
    next_id = None
    count = 0
    while True:
        points, next_id = await db.scroll_points(next_id, count=100)
        for point in points:
            count += 1
            logger.info("[{}] Migrating point {}", count, point.id)
            if point.url.startswith('/'):
                # V1 database assumed every image whose URL begins with '/' was a local image;
                # v2 migrates to a stricter approach with an explicit flag.
                point.local = True
            await db.updatePayload(point)  # This will also store ocr_text_lower field, if present
            if point.ocr_text is not None:
                point.text_contain_vector = services.transformers_service.get_bert_vector(point.ocr_text_lower)

        # Update vectors for this group of points, but only the ones that actually carry
        # a text vector. Skip the DB round trip when there are none
        # (e.g. an empty page).
        vectorized = [p for p in points if p.text_contain_vector is not None]
        if vectorized:
            logger.info("Updating vectors...")
            await db.updateVectors(vectorized)
        if next_id is None:
            break

    logger.info("Migration from v1 to v2 finished, {} points processed.", count)


async def migrate(from_version: int):
    global services
    services = ServiceProvider()
    await services.onload()
    match from_version:
        case 1:
            await migrate_v1_v2()
        case 2:
            logger.info("Already up to date.")
        case _:
            raise ValueError(f"Unknown version {from_version}")