# airbnb/caravel (view on GitHub)
# superset/utils/decorators.py
#
# Summary
# Maintainability: A (3 hrs)
# Test Coverage
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import time
from collections.abc import Iterator
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, TYPE_CHECKING
from uuid import UUID

from flask import current_app, g, Response
from sqlalchemy.exc import SQLAlchemyError

from superset.utils import core as utils
from superset.utils.dates import now_as_float

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Only needed for type annotations (``stats_timing``); with postponed annotation
    # evaluation enabled at the top of the file it is never imported at runtime.
    from superset.stats_logger import BaseStatsLogger


def statsd_gauge(metric_prefix: str | None = None) -> Callable[..., Any]:
    """
    Handle sending statsd gauge metric from any method or function.

    Each call of the decorated function sends a gauge value of ``1`` through
    ``current_app.config["STATS_LOGGER"]``:

    - ``<prefix>.ok`` when the call returns normally;
    - ``<prefix>.warning`` when it raises an exception whose ``status`` is below 500;
    - ``<prefix>.error`` for any other exception.

    Exceptions are always re-raised unchanged.

    :param metric_prefix: Metric name prefix; defaults to the decorated function's name
    """

    def _is_warning(ex: Exception) -> bool:
        """Return True when ``ex`` has a ``status`` attribute that compares below 500."""
        try:
            # pylint: disable=no-member
            below_500 = ex.status < 500  # type: ignore[attr-defined]
        except (AttributeError, TypeError):
            # A missing or non-comparable ``status`` (e.g. ``None``) counts as an
            # error, rather than letting a TypeError replace the exception that is
            # being reported.
            return False
        return bool(below_500)

    def decorate(f: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(f)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            metric_prefix_ = metric_prefix or f.__name__
            try:
                result = f(*args, **kwargs)
                current_app.config["STATS_LOGGER"].gauge(f"{metric_prefix_}.ok", 1)
                return result
            except Exception as ex:
                outcome = "warning" if _is_warning(ex) else "error"
                current_app.config["STATS_LOGGER"].gauge(
                    f"{metric_prefix_}.{outcome}", 1
                )
                raise

        return wrapped

    return decorate


def logs_context(
    context_func: Callable[..., dict[Any, Any]] | None = None,
    **ctx_kwargs: int | str | UUID | None,
) -> Callable[..., Any]:
    """
    Takes arguments and adds them to the global logs_context.
    This is for logging purposes only and values should not be relied on or mutated

    Values are collected into ``g.logs_context`` in increasing order of precedence:

    1. keyword arguments of the decorated call, e.g. ``my_func(slice_id=2)``;
    2. keyword arguments given to the decorator, e.g. ``@logs_context(slice_id=1)``;
    3. the dict returned by ``context_func(*args, **kwargs)``, if provided.

    Only allow-listed keys are kept and ``None`` values are ignored. If gathering the
    context raises ``TypeError``, ``KeyError`` or ``AttributeError``, a warning is
    logged and the decorated function still runs.

    :param context_func: Optional callable receiving the decorated call's arguments
        and returning a dict of extra context values
    :param ctx_kwargs: Fixed context values supplied to the decorator itself
    """
    # limit data that can be saved to logs_context in order to prevent antipatterns;
    # built once per decorator rather than on every call (order is preserved so the
    # insertion order of ``g.logs_context`` is unchanged)
    available_logs_context_keys = (
        "slice_id",
        "dashboard_id",
        "dataset_id",
        "execution_id",
        "report_schedule_id",
    )

    def decorate(f: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(f)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            if not hasattr(g, "logs_context"):
                g.logs_context = {}

            # set value from kwargs of the wrapped function if it exists
            # e.g. @logs_context()
            #      def my_func(slice_id=None, **kwargs)
            #
            #      my_func(slice_id=2)
            logs_context_data = {
                key: val
                for key, val in kwargs.items()
                if key in available_logs_context_keys and val is not None
            }

            try:
                # keys passed to the decorator directly override values from kwargs
                # e.g. @logs_context(slice_id=1, dashboard_id=1)
                logs_context_data.update(
                    {
                        key: ctx_kwargs.get(key)
                        for key in available_logs_context_keys
                        if ctx_kwargs.get(key) is not None
                    }
                )

                if context_func is not None:
                    # values returned by the context function take precedence, e.g.
                    # context_func=lambda *args, **kwargs: {
                    #     "slice_id": 1, "dashboard_id": 1
                    # }
                    logs_context_data.update(
                        {
                            key: value
                            for key, value in context_func(*args, **kwargs).items()
                            if key in available_logs_context_keys and value is not None
                        }
                    )

            except (TypeError, KeyError, AttributeError):
                # do nothing if the key doesn't exist
                # or context is not callable
                logger.warning("Invalid data was passed to the logs context decorator")

            g.logs_context.update(logs_context_data)
            return f(*args, **kwargs)

        return wrapped

    return decorate


@contextmanager
def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[float]:
    """
    Time the enclosed block and report the elapsed time to ``stats_logger``.

    ``stats_logger.timing(stats_key, elapsed)`` is called when the block exits,
    whether it finishes normally or raises. ``elapsed`` is the difference of two
    ``now_as_float()`` readings, in whatever unit that helper returns.

    :param stats_key: Metric key passed to ``stats_logger.timing``
    :param stats_logger: Stats logger that receives the measurement
    :yields: The start timestamp as returned by ``now_as_float()``
    """
    start_ts = now_as_float()
    try:
        yield start_ts
    finally:
        # Runs on both success and error so the timing is always reported.
        stats_logger.timing(stats_key, now_as_float() - start_ts)


def arghash(args: Any, kwargs: Any) -> int:
    """
    Simple argument hash with kwargs sorted.

    Keyword arguments are sorted by name, so ``a=1, b=2`` and ``b=2, a=1`` hash
    identically.

    :param args: Positional argument values (any iterable)
    :param kwargs: Keyword arguments mapping
    :returns: ``hash`` of the positional values followed by the sorted
        ``(name, value)`` pairs
    :raises TypeError: If any argument value is unhashable
    """
    return hash((*args, *sorted(kwargs.items())))


def debounce(duration: float | int = 0.1) -> Callable[..., Any]:
    """Ensure a function called with the same arguments executes only once
    per `duration` seconds (default: 100ms).

    Only the most recent call is remembered. A call whose arguments equal the
    previous execution's, made less than ``duration`` seconds after that execution
    finished, returns the cached result without invoking the function. Any other
    call invokes the function and replaces the cached entry.

    :param duration: Debounce window in seconds
    """

    def decorate(f: Callable[..., Any]) -> Callable[..., Any]:
        # Single-entry cache describing the most recent execution.
        last: dict[str, Any] = {"t": None, "input": None, "output": None}

        @wraps(f)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            # A monotonic clock cannot jump when the wall clock is adjusted, so the
            # debounce window is always measured correctly.
            now = time.monotonic()
            # Compare the arguments themselves, not their hash: distinct arguments
            # may share a hash (e.g. hash(-1) == hash(-2) in CPython), which would
            # wrongly return the cached output of a different call.
            call_args = (args, kwargs)
            if (
                last["t"] is None
                or now - last["t"] >= duration
                or last["input"] != call_args
            ):
                result = f(*args, **kwargs)
                last["t"] = time.monotonic()
                last["input"] = call_args
                last["output"] = result
                return result
            return last["output"]

        return wrapped

    return decorate


def on_security_exception(self: Any, ex: Exception) -> Response:
    """
    Build a 403 response describing a security exception.

    :param self: Object whose ``response`` method is called with the status code
        and a ``message`` keyword argument
    :param ex: The exception being handled; its message is derived via
        ``utils.error_msg_from_exception``
    :returns: Whatever ``self.response`` returns
    """
    return self.response(403, message=utils.error_msg_from_exception(ex))


@contextmanager
def suppress_logging(
    logger_name: str | None = None,
    new_level: int = logging.CRITICAL,
) -> Iterator[None]:
    """
    Context manager to suppress logging during the execution of code block.

    The named logger's level is set to ``new_level`` for the duration of the block
    and restored afterwards, even if the block raises.

    Use with caution and make sure you have the least amount of code inside it.

    :param logger_name: Name of the logger to adjust; ``None`` selects the root logger
    :param new_level: Level applied while the block runs
    """
    target_logger = logging.getLogger(logger_name)
    # Save the logger's own level (possibly NOTSET), not its effective level: a
    # logger inheriting its level from a parent must keep inheriting afterwards
    # instead of being pinned to the parent's level at this moment.
    original_level = target_logger.level
    target_logger.setLevel(new_level)
    try:
        yield
    finally:
        target_logger.setLevel(original_level)


def on_error(
    ex: Exception,
    catches: tuple[type[Exception], ...] = (SQLAlchemyError,),
    reraise: type[Exception] | None = SQLAlchemyError,
) -> None:
    """
    Default error callback used by the ``transaction`` decorator.

    Exceptions that are not instances of ``catches`` are raised again unchanged.
    Caught exceptions have their ``exception`` attribute (when present) logged, and
    are then converted into a fresh ``reraise()`` chained to the original. When
    ``reraise`` is falsy the caught exception is swallowed.

    :param ex: The source exception
    :param catches: The exception types the handler catches
    :param reraise: The exception type the handler raises after catching
    :raises Exception: If the exception is not swallowed
    """

    # Anything outside the caught types propagates untouched.
    if not isinstance(ex, catches):
        raise ex

    # Some exceptions carry an inner ``exception`` attribute worth logging.
    if hasattr(ex, "exception"):
        logger.exception(ex.exception)

    if reraise:
        raise reraise() from ex


def transaction(  # pylint: disable=redefined-outer-name
    on_error: Callable[..., Any] | None = on_error,
) -> Callable[..., Any]:
    """
    Perform a "unit of work".

    The decorated function runs and ``db.session`` is committed afterwards. If the
    function or the commit raises, the session is rolled back. The exception is then
    passed to ``on_error``, whose return value becomes the call's result; without an
    ``on_error`` callback the exception propagates.

    Note ideally this would leverage SQLAlchemy's nested transaction, however this
    proved rather complicated, likely due to many architectural facets, and thus has
    been left for a follow up exercise.

    :param on_error: Callback invoked when an exception is caught
    :see: https://github.com/apache/superset/issues/25108
    """

    def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            from superset import db  # pylint: disable=import-outside-toplevel

            try:
                outcome = func(*args, **kwargs)
                # The commit stays inside the try so a failing commit is rolled back
                # and routed to ``on_error`` as well.
                db.session.commit()  # pylint: disable=consider-using-transaction
            except Exception as ex:
                db.session.rollback()  # pylint: disable=consider-using-transaction
                if not on_error:
                    raise
                return on_error(ex)
            return outcome

        return wrapped

    return decorate