airbnb/caravel

View on GitHub
superset/daos/report.py

Summary

Maintainability
A
0 mins
Test Coverage
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from datetime import datetime
from typing import Any

from superset.daos.base import BaseDAO
from superset.extensions import db
from superset.reports.filters import ReportScheduleFilter
from superset.reports.models import (
    ReportExecutionLog,
    ReportRecipients,
    ReportSchedule,
    ReportScheduleType,
    ReportState,
)
from superset.utils import json
from superset.utils.core import get_user_id

logger = logging.getLogger(__name__)


REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER = "Notification sent with error"


class ReportScheduleDAO(BaseDAO[ReportSchedule]):
    base_filter = ReportScheduleFilter

    @staticmethod
    def find_by_chart_id(chart_id: int) -> list[ReportSchedule]:
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.chart_id == chart_id)
            .all()
        )

    @staticmethod
    def find_by_chart_ids(chart_ids: list[int]) -> list[ReportSchedule]:
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.chart_id.in_(chart_ids))
            .all()
        )

    @staticmethod
    def find_by_dashboard_id(dashboard_id: int) -> list[ReportSchedule]:
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.dashboard_id == dashboard_id)
            .all()
        )

    @staticmethod
    def find_by_dashboard_ids(dashboard_ids: list[int]) -> list[ReportSchedule]:
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.dashboard_id.in_(dashboard_ids))
            .all()
        )

    @staticmethod
    def find_by_database_id(database_id: int) -> list[ReportSchedule]:
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.database_id == database_id)
            .all()
        )

    @staticmethod
    def find_by_database_ids(database_ids: list[int]) -> list[ReportSchedule]:
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.database_id.in_(database_ids))
            .all()
        )

    @staticmethod
    def find_by_extra_metadata(slug: str) -> list[ReportSchedule]:
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.extra_json.like(f"%{slug}%"))
            .all()
        )

    @staticmethod
    def validate_unique_creation_method(
        dashboard_id: int | None = None, chart_id: int | None = None
    ) -> bool:
        """
        Validate if the user already has a chart or dashboard
        with a report attached form the self subscribe reports
        """

        query = db.session.query(ReportSchedule).filter_by(created_by_fk=get_user_id())
        if dashboard_id is not None:
            query = query.filter(ReportSchedule.dashboard_id == dashboard_id)

        if chart_id is not None:
            query = query.filter(ReportSchedule.chart_id == chart_id)

        return not db.session.query(query.exists()).scalar()

    @staticmethod
    def validate_update_uniqueness(
        name: str, report_type: ReportScheduleType, expect_id: int | None = None
    ) -> bool:
        """
        Validate if this name and type is unique.

        :param name: The report schedule name
        :param report_type: The report schedule type
        :param expect_id: The id of the expected report schedule with the
          name + type combination. Useful for validating existing report schedule.
        :return: bool
        """
        found_id = (
            db.session.query(ReportSchedule.id)
            .filter(ReportSchedule.name == name, ReportSchedule.type == report_type)
            .limit(1)
            .scalar()
        )
        return found_id is None or found_id == expect_id

    @classmethod
    def create(
        cls,
        item: ReportSchedule | None = None,
        attributes: dict[str, Any] | None = None,
    ) -> ReportSchedule:
        """
        Create a report schedule with nested recipients.

        :param item: The object to create
        :param attributes: The attributes associated with the object to create
        """

        # TODO(john-bodley): Determine why we need special handling for recipients.
        if not item:
            item = ReportSchedule()

        if attributes:
            if recipients := attributes.pop("recipients", None):
                attributes["recipients"] = [
                    ReportRecipients(
                        type=recipient["type"],
                        recipient_config_json=json.dumps(
                            recipient["recipient_config_json"]
                        ),
                        report_schedule=item,
                    )
                    for recipient in recipients
                ]

        return super().create(item, attributes)

    @classmethod
    def update(
        cls,
        item: ReportSchedule | None = None,
        attributes: dict[str, Any] | None = None,
    ) -> ReportSchedule:
        """
        Update a report schedule with nested recipients.

        :param item: The object to update
        :param attributes: The attributes associated with the object to update
        """

        # TODO(john-bodley): Determine why we need special handling for recipients.
        if not item:
            item = ReportSchedule()

        if attributes:
            if recipients := attributes.pop("recipients", None):
                attributes["recipients"] = [
                    ReportRecipients(
                        type=recipient["type"],
                        recipient_config_json=json.dumps(
                            recipient["recipient_config_json"]
                        ),
                        report_schedule=item,
                    )
                    for recipient in recipients
                ]

        return super().update(item, attributes)

    @staticmethod
    def find_active() -> list[ReportSchedule]:
        """
        Find all active reports.
        """
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.active.is_(True))
            .all()
        )

    @staticmethod
    def find_last_success_log(
        report_schedule: ReportSchedule,
    ) -> ReportExecutionLog | None:
        """
        Finds last success execution log for a given report
        """
        return (
            db.session.query(ReportExecutionLog)
            .filter(
                ReportExecutionLog.state == ReportState.SUCCESS,
                ReportExecutionLog.report_schedule == report_schedule,
            )
            .order_by(ReportExecutionLog.end_dttm.desc())
            .first()
        )

    @staticmethod
    def find_last_entered_working_log(
        report_schedule: ReportSchedule,
    ) -> ReportExecutionLog | None:
        """
        Finds last success execution log for a given report
        """
        return (
            db.session.query(ReportExecutionLog)
            .filter(
                ReportExecutionLog.state == ReportState.WORKING,
                ReportExecutionLog.report_schedule == report_schedule,
                ReportExecutionLog.error_message.is_(None),
            )
            .order_by(ReportExecutionLog.end_dttm.desc())
            .first()
        )

    @staticmethod
    def find_last_error_notification(
        report_schedule: ReportSchedule,
    ) -> ReportExecutionLog | None:
        """
        Finds last error email sent
        """
        last_error_email_log = (
            db.session.query(ReportExecutionLog)
            .filter(
                ReportExecutionLog.error_message
                == REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER,
                ReportExecutionLog.report_schedule == report_schedule,
            )
            .order_by(ReportExecutionLog.end_dttm.desc())
            .first()
        )
        if not last_error_email_log:
            return None
        # Checks that only errors have occurred since the last email
        report_from_last_email = (
            db.session.query(ReportExecutionLog)
            .filter(
                ReportExecutionLog.state.notin_(
                    [ReportState.ERROR, ReportState.WORKING]
                ),
                ReportExecutionLog.report_schedule == report_schedule,
                ReportExecutionLog.end_dttm < last_error_email_log.end_dttm,
            )
            .order_by(ReportExecutionLog.end_dttm.desc())
            .first()
        )
        return last_error_email_log if not report_from_last_email else None

    @staticmethod
    def bulk_delete_logs(model: ReportSchedule, from_date: datetime) -> int | None:
        return (
            db.session.query(ReportExecutionLog)
            .filter(
                ReportExecutionLog.report_schedule == model,
                ReportExecutionLog.end_dttm < from_date,
            )
            .delete(synchronize_session="fetch")
        )