# CIMAC-CIDC/cidc-api-gae: cidc_api/models/migrations.py
import os
import traceback
from contextlib import contextmanager
from functools import partial
from typing import Callable, List, NamedTuple

from alembic import op
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.attributes import flag_modified
from google.cloud import storage

from cidc_schemas.migrations import MigrationResult
from cidc_schemas.prism.merger import _get_uuid_path_map, get_source

from .models import (
    TrialMetadata,
    DownloadableFiles,
    CommonColumns,
    UploadJobStatus,
    UploadJobs,
    String,
    Column,
)
from ..shared.gcloud_client import publish_artifact_upload
from ..config.settings import GOOGLE_ACL_DATA_BUCKET, GOOGLE_UPLOAD_BUCKET


class PieceOfWork(NamedTuple):
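    """A reversible unit of work: `do` performs a side effect and `undo` reverses it."""
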
    do: Callable[[], None]
    undo: Callable[[], None]


class RollbackableQueue:
    """A collection of reversible pieces-of-work."""

    def __init__(self):
        self.tasks = []
        self.done = set()

    def schedule(self, task: PieceOfWork):
        """Add a task to the task queue."""
        self.tasks.append(task)

    def run_all(self):
        """
        Attempt to run all tasks in the queue, rolling back
        successfully completed tasks if a subsequent task fails.
        """
        for i, task in enumerate(self.tasks):
            try:
                task.do()
                self.done.add(i)
            except:
                self.rollback()
                raise

    def rollback(self):
        """
        Undo any work that has been carried out.
        """
        for i, task in enumerate(self.tasks):
            if i in self.done:
                task.undo()
                self.done.discard(i)

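# Illustrative usage of PieceOfWork + RollbackableQueue (not executed anywhere in
# this module; `create_thing` and `delete_thing` are hypothetical callables):
#
#     queue = RollbackableQueue()
#     queue.schedule(PieceOfWork(do=lambda: create_thing("a"), undo=lambda: delete_thing("a")))
#     queue.schedule(PieceOfWork(do=lambda: create_thing("b"), undo=lambda: delete_thing("b")))
#     # If create_thing("b") raises, delete_thing("a") runs and the exception propagates.
#     queue.run_all()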

@contextmanager
def migration_session():
    session = Session(bind=op.get_bind())
    task_queue = RollbackableQueue()

    try:
        yield session, task_queue
        print("Commiting SQL session...")
        session.commit()
        print("Session commit succeeded.")
    except Exception as e:
        print(f"Encountered exception: {e.__class__}\n{e}")
        print("Running SQL rollback...")
        session.rollback()
        print("SQL rollback succeeded.")
        if task_queue:
            try:
                print("Running GCS rollback...")
                task_queue.rollback()
                print("GCS rollback succeeded.")
            except Exception as e:
                print(f"GCS rollback failed: {e.__class__}\n{e}")
        raise
    finally:
        session.close()

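# Usage sketch for migration_session (the same pattern run_metadata_migration
# follows below): mutate ORM objects through `session` and schedule reversible
# GCS work on `task_queue`. On failure, the SQL transaction is rolled back first,
# then any already-completed GCS tasks are undone.
#
#     with migration_session() as (session, task_queue):
#         ...  # query/update rows, task_queue.schedule(...) GCS renames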

def run_metadata_migration(
    metadata_migration: Callable[[dict], MigrationResult], use_upload_jobs_table: bool
):
    """Migrate trial metadata, upload job patches, and downloadable files according to `metadata_migration`"""
    with migration_session() as (session, task_queue):
        try:
            _run_metadata_migration(
                metadata_migration, use_upload_jobs_table, task_queue, session
            )
        except:
            traceback.print_exc()
            raise

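# Sketch of how an Alembic revision might call run_metadata_migration. The names
# `upgrade_metadata` and `downgrade_metadata` are hypothetical stand-ins for a real
# cidc_schemas.migrations pair returning MigrationResult:
#
#     def upgrade():
#         run_metadata_migration(upgrade_metadata, use_upload_jobs_table=True)
#
#     def downgrade():
#         run_metadata_migration(downgrade_metadata, use_upload_jobs_table=True)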

def _select_trials(session: Session) -> List[TrialMetadata]:
    return session.query(TrialMetadata).with_for_update().all()


class AssayUploads(CommonColumns):
    """This model no longer exists in cidc_api.models, but a partial model is required for migrations."""

    __tablename__ = "assay_uploads"
    status = Column(String)


class ManifestUploads(CommonColumns):
    """This model no longer exists in cidc_api.models, but a partial model is required for migrations."""

    __tablename__ = "manifest_uploads"


def _select_successful_assay_uploads(
    use_upload_jobs_table: bool, session: Session
) -> List[UploadJobs]:
    if use_upload_jobs_table:
        return (
            session.query(UploadJobs)
            .filter_by(status=UploadJobStatus.MERGE_COMPLETED.value, multifile=True)
            .with_for_update()
            .all()
        )

    return (
        session.query(AssayUploads)
        .filter_by(status=UploadJobStatus.MERGE_COMPLETED.value)
        .with_for_update()
        .all()
    )


def _select_manifest_uploads(
    use_upload_jobs_table: bool, session: Session
) -> List[UploadJobs]:
    if use_upload_jobs_table:
        return (
            session.query(UploadJobs).filter_by(multifile=False).with_for_update().all()
        )

    return session.query(ManifestUploads).with_for_update().all()


def _run_metadata_migration(
    metadata_migration: Callable[[dict], MigrationResult],
    use_upload_jobs_table: bool,
    gcs_tasks: RollbackableQueue,
    session: Session,
):
    # Migrate all trial records
    trials = _select_trials(session)
    for trial in trials:
        print(f"Running metadata migration for trial: {trial.trial_id}")
        migration = metadata_migration(trial.metadata_json)

        # Update the trial metadata object
        trial.safely_set_metadata_json(migration.result)

        # Workaround: SQLAlchemy doesn't detect in-place mutations of JSON
        # columns, so explicitly flag the field as modified. Declaring the column
        # as MutableDict.as_mutable(JSON) in the model doesn't seem to help.
        flag_modified(trial, "metadata_json")

        # If this trial has no file updates, move on to the next one
        if len(migration.file_updates) == 0:
            continue

        # Update the relevant downloadable files and GCS objects
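        # Assumption based on usage below: _get_uuid_path_map maps each artifact's
        # upload_placeholder UUID to its path within the migrated metadata JSON,
        # which get_source then uses to regenerate additional_metadata.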
        uuid_path_map = _get_uuid_path_map(migration.result)
        for old_gcs_uri, artifact in migration.file_updates.items():
            print(f"Updating GCS and artifact info for {old_gcs_uri}: {artifact}")
            # Update the downloadable file associated with this blob
            df = DownloadableFiles.get_by_object_url(old_gcs_uri, session=session)
            for column, value in artifact.items():
                if hasattr(df, column):
                    setattr(df, column, value)

            # Regenerate additional metadata from the migrated clinical trial
            # metadata object.
            print(
                f"Regenerating additional metadata for artifact with uuid {artifact['upload_placeholder']}"
            )
            artifact_path = uuid_path_map[artifact["upload_placeholder"]]
            df.additional_metadata = get_source(
                migration.result, artifact_path, skip_last=True
            )[1]

            # If the GCS URI has changed, schedule a reversible rename of the
            # blob (ultimately a call to bucket.rename_blob via rename_gcs_blob)
            new_gcs_uri = artifact["object_url"]
            if old_gcs_uri != new_gcs_uri:
                print(
                    f"Encountered GCS data bucket artifact URI to update: {old_gcs_uri}"
                )
                renamer = PieceOfWork(
                    partial(
                        rename_gcs_blob, GOOGLE_ACL_DATA_BUCKET, old_gcs_uri, new_gcs_uri
                    ),
                    partial(
                        rename_gcs_blob, GOOGLE_ACL_DATA_BUCKET, new_gcs_uri, old_gcs_uri
                    ),
                )
                gcs_tasks.schedule(renamer)

    # Migrate all assay upload successes
    successful_assay_uploads = _select_successful_assay_uploads(
        use_upload_jobs_table, session
    )
    for upload in successful_assay_uploads:
        print(f"Running metadata migration for assay upload: {upload.id}")
        if use_upload_jobs_table:
            migration = metadata_migration(upload.metadata_patch)
            upload.metadata_patch = migration.result
            # Workaround: SQLAlchemy doesn't detect in-place mutations of JSON
            # columns, so explicitly flag the field as modified. Declaring the column
            # as MutableDict.as_mutable(JSON) in the model doesn't seem to help.
            flag_modified(upload, "metadata_patch")
        else:
            migration = metadata_migration(upload.assay_patch)
            upload.assay_patch = migration.result
            flag_modified(upload, "assay_patch")

        # Update the GCS URIs of files that were part of this upload
        new_file_map = {}
        for (
            old_upload_uri,
            old_target_uri,
            artifact_uuid,
        ) in upload.upload_uris_with_data_uris_with_uuids():
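            # Upload-bucket URIs appear to have the form
            # "<target object URL>/<upload timestamp>", so slicing off the target
            # URL (plus the "/" separator) leaves just the timestamp.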
            upload_timestamp = old_upload_uri[len(old_target_uri) + 1 :]
            # Default to the existing upload URI; it is only replaced below if the
            # migration changed this file's target object URL.
            new_upload_uri = old_upload_uri
            if old_target_uri in migration.file_updates:
                new_target_uri = migration.file_updates[old_target_uri]["object_url"]
                if old_target_uri != new_target_uri:
                    print(
                        f"Encountered GCS upload bucket artifact URI to update: {old_upload_uri}"
                    )
                    new_upload_uri = "/".join([new_target_uri, upload_timestamp])
                    renamer = PieceOfWork(
                        partial(
                            rename_gcs_blob,
                            GOOGLE_UPLOAD_BUCKET,
                            old_upload_uri,
                            new_upload_uri,
                        ),
                        partial(
                            rename_gcs_blob,
                            GOOGLE_UPLOAD_BUCKET,
                            new_upload_uri,
                            old_upload_uri,
                        ),
                    )
                    gcs_tasks.schedule(renamer)
            # Record every file in the new map (renamed or not) so no entries
            # are dropped from gcs_file_map.
            new_file_map[new_upload_uri] = artifact_uuid

        # Update the upload's file map to use new GCS URIs
        upload.gcs_file_map = new_file_map

    # Migrate all manifest records
    manifest_uploads = _select_manifest_uploads(use_upload_jobs_table, session)
    for upload in manifest_uploads:
        print(f"Running metadata migration for manifest upload: {upload.id}")
        migration = metadata_migration(upload.metadata_patch)

        # Update the metadata patch
        upload.metadata_patch = migration.result

        # Workaround: SQLAlchemy doesn't detect in-place mutations of JSON
        # columns, so explicitly flag the field as modified. Declaring the column
        # as MutableDict.as_mutable(JSON) in the model doesn't seem to help.
        flag_modified(upload, "metadata_patch")

    # Attempt to make GCS updates
    print(f"Running all GCS tasks...")
    gcs_tasks.run_all()
    print(f"GCS tasks succeeded.")


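# Guard against performing real GCS operations and Pub/Sub publishes when running
# tests or in the dev environment.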
dont_run = os.environ.get("TESTING") or os.environ.get("ENV") == "dev"


def rename_gcs_blob(bucket_name, old_name, new_name):
    full_old_uri = f"gs://{bucket_name}/{old_name}"
    full_new_uri = f"gs://{bucket_name}/{new_name}"
    message = f"GCS: moving {full_old_uri} to {full_new_uri}"
    if dont_run:
        print(f"SKIPPING: {message}")
        return

    print(message)

    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    old_blob = bucket.blob(old_name)
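    # Note: per the google-cloud-storage docs, rename_blob copies the blob to the
    # new name and then deletes the original, so the rename is not atomic.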
    new_blob = bucket.rename_blob(old_blob, new_name)
    return new_blob


def republish_artifact_uploads():
    """
    Publish all downloadable_file IDs to the `artifact_upload` Pub/Sub topic,
    triggering downstream file post-processing (e.g., pre-computation for visualization
    purposes).
    """
    if dont_run:
        print("Skipping 'republish_artifact_uploads' because this is a test")
        return

    with migration_session() as (session, _):
        files = session.query(DownloadableFiles).all()
        for f in files:
            print(
                f"Publishing to the 'artifact_upload' topic for downloadable file {f.object_url}"
            )
            publish_artifact_upload(f.object_url)