airbnb/caravel

View on GitHub
superset/daos/query.py

Summary

Maintainability
A
0 mins
Test Coverage
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime, timezone
from typing import Any, Union

from superset import sql_lab
from superset.common.db_query_status import QueryStatus
from superset.daos.base import BaseDAO
from superset.exceptions import QueryNotFoundException, SupersetCancelQueryException
from superset.extensions import db
from superset.models.sql_lab import Query, SavedQuery
from superset.queries.filters import QueryFilter
from superset.queries.saved_queries.filters import SavedQueryFilter
from superset.utils.core import get_user_id
from superset.utils.dates import now_as_float

logger = logging.getLogger(__name__)


class QueryDAO(BaseDAO[Query]):
    """Data-access helpers for ``Query`` rows and their related saved queries."""

    base_filter = QueryFilter

    @staticmethod
    def update_saved_query_exec_info(query_id: int) -> None:
        """
        Propagates query execution info back to saved query if applicable

        Every saved query that has the same database and the exact same SQL
        text as the executed query gets its ``rows`` and ``last_run`` updated.
        This method does not commit the session itself.

        :param query_id: The query id
        :return:
        """
        query = db.session.query(Query).get(query_id)
        if query is None:
            # ``get`` yields None for an unknown id; there is nothing to
            # propagate, so skip instead of failing on ``query.database``.
            logger.warning(
                "Query with id %s not found; saved query info was not updated",
                query_id,
            )
            return

        related_saved_queries = (
            db.session.query(SavedQuery)
            .filter(SavedQuery.database == query.database)
            .filter(SavedQuery.sql == query.sql)
            .all()
        )
        # One timestamp for the whole batch so all matches agree on last_run.
        last_run = datetime.now()
        for saved_query in related_saved_queries:
            saved_query.rows = query.rows
            saved_query.last_run = last_run

    @staticmethod
    def save_metadata(query: Query, payload: dict[str, Any]) -> None:
        """
        Stores the payload's column descriptions in the query's extra JSON

        Each column dict that has a ``name`` key is given a ``column_name``
        key with the same value (the dicts inside ``payload`` are mutated in
        place), then the columns are saved under the ``columns`` extra key.

        :param query: The query to attach the metadata to
        :param payload: The result payload; only its ``columns`` entry is read
        :return:
        """
        # pull relevant data from payload and store in extra_json
        columns = payload.get("columns", {})
        for col in columns:
            if "name" in col:
                col["column_name"] = col["name"]
        db.session.add(query)
        query.set_extra_json_key("columns", columns)

    @staticmethod
    def get_queries_changed_after(last_updated_ms: Union[float, int]) -> list[Query]:
        """
        Returns the current user's queries changed at or after a timestamp

        :param last_updated_ms: Epoch timestamp in milliseconds
        :return: Queries owned by the current user whose ``changed_on`` is
            greater than or equal to the given instant
        """
        # UTC date time, same that is stored in the DB.
        # ``utcfromtimestamp`` is deprecated since Python 3.12; build an aware
        # UTC value and drop the tzinfo to get the same naive datetime.
        last_updated_dt = datetime.fromtimestamp(
            last_updated_ms / 1000, tz=timezone.utc
        ).replace(tzinfo=None)

        return (
            db.session.query(Query)
            .filter(Query.user_id == get_user_id(), Query.changed_on >= last_updated_dt)
            .all()
        )

    @staticmethod
    def stop_query(client_id: str) -> None:
        """
        Cancels the query identified by ``client_id`` and marks it as stopped

        A query that already failed, succeeded or timed out is left untouched
        and only a warning is logged.

        :param client_id: The client id of the query to stop
        :raises QueryNotFoundException: If no query has this client id
        :raises SupersetCancelQueryException: If the cancellation fails
        :return:
        """
        query = db.session.query(Query).filter_by(client_id=client_id).one_or_none()
        if not query:
            raise QueryNotFoundException(f"Query with client_id {client_id} not found")

        if query.status in [
            QueryStatus.FAILED,
            QueryStatus.SUCCESS,
            QueryStatus.TIMED_OUT,
        ]:
            # Include the id so the warning can be traced to a specific query.
            logger.warning(
                "Query with client_id %s could not be stopped: "
                "query already complete",
                client_id,
            )
            return

        if not sql_lab.cancel_query(query):
            raise SupersetCancelQueryException("Could not cancel query")

        query.status = QueryStatus.STOPPED
        query.end_time = now_as_float()


class SavedQueryDAO(BaseDAO[SavedQuery]):
    """
    Data-access object for ``SavedQuery`` rows.

    Adds no methods of its own; it only sets ``base_filter`` to
    ``SavedQueryFilter``.
    """

    # NOTE(review): presumably BaseDAO applies this filter to restrict which
    # saved queries are visible -- confirm against BaseDAO.
    base_filter = SavedQueryFilter