airbnb/caravel

View on GitHub
superset/extensions/metadb.py

Summary

Maintainability
C
1 day
Test Coverage
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
A SQLAlchemy dialect for querying across Superset databases.

The dialect ``superset://`` allows users to query any table in any database that has been
configured in Superset, e.g.:

    > SELECT * FROM "examples.birth_names";

The syntax for tables is:

    database[[.catalog].schema].table

The dialect is built on top of Shillelagh, a framework for building DB API 2.0 libraries
and SQLAlchemy dialects based on SQLite. SQLite parses the SQL and passes the filters
to the adapter. The adapter builds a SQLAlchemy query object that reads data from the table
and applies any filters (as well as sorting, limiting, and offsetting).

Note that no aggregation is pushed down to the underlying database. Aggregations and other
operations like joins and unions are done in memory, using the SQLite engine.
"""

from __future__ import annotations

import datetime
import decimal
import operator
import urllib.parse
from collections.abc import Iterator
from functools import partial, wraps
from typing import Any, Callable, cast, TypeVar

from flask import current_app
from shillelagh.adapters.base import Adapter
from shillelagh.backends.apsw.dialects.base import APSWDialect
from shillelagh.exceptions import ProgrammingError
from shillelagh.fields import (
    Boolean,
    Date,
    DateTime,
    Field,
    Float,
    Integer,
    Order,
    String,
    Time,
)
from shillelagh.filters import Equal, Filter, Range
from shillelagh.typing import RequestedOrder, Row
from sqlalchemy import func, MetaData, Table
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.sql import Select, select

from superset import db, feature_flag_manager, security_manager, sql_parse


# pylint: disable=abstract-method
class SupersetAPSWDialect(APSWDialect):
    """
    SQLAlchemy dialect that queries Superset databases through Shillelagh.

    A single query can read from (and join across) any of the databases
    configured in Superset. For example, to read the `birth_names` table from
    the `examples` database:

        >>> engine = create_engine('superset://')
        >>> conn = engine.connect()
        >>> results = conn.execute('SELECT * FROM "examples.birth_names"')

    The dialect is built on top of the Shillelagh library, which uses SQLite to
    create virtual tables on the fly that proxy Superset tables. Data for those
    virtual tables is supplied by the `SupersetShillelaghAdapter` adapter.
    """

    name = "superset"

    def __init__(self, allowed_dbs: list[str] | None = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        # Optional allow-list of database names, forwarded to the adapter by
        # ``create_connect_args``.
        self.allowed_dbs = allowed_dbs

    def create_connect_args(self, url: URL) -> tuple[tuple[()], dict[str, Any]]:
        """
        Build the DB API connection arguments for this dialect.

        The connection uses an in-memory SQLite database with exactly one
        adapter, ``superset``, registered. That adapter receives no table prefix
        and this dialect's ``allowed_dbs``. ``url`` itself is not consulted.
        """
        superset_kwargs = {
            "prefix": None,
            "allowed_dbs": self.allowed_dbs,
        }
        connect_kwargs: dict[str, Any] = {
            "path": ":memory:",
            "adapters": ["superset"],
            "adapter_kwargs": {"superset": superset_kwargs},
            "safe": True,
            "isolation_level": self.isolation_level,
        }
        return (), connect_kwargs


# Type variable for the method decorators in this module: each decorator casts
# its wrapper back to ``F`` so the decorated method keeps its original signature
# for type checkers.
F = TypeVar("F", bound=Callable[..., Any])


def check_dml(method: F) -> F:
    """
    Guard an adapter method so it only runs when the database allows DML.

    If the adapter's ``_allow_dml`` flag is falsy, the wrapped method is not
    called and ``ProgrammingError`` is raised instead.
    """

    @wraps(method)
    def guarded(self: SupersetShillelaghAdapter, *args: Any, **kwargs: Any) -> Any:
        # pylint: disable=protected-access
        dml_allowed = self._allow_dml
        if dml_allowed:
            return method(self, *args, **kwargs)
        raise ProgrammingError(f'DML not enabled in database "{self.database}"')

    return cast(F, guarded)


def has_rowid(method: F) -> F:
    """
    Guard an adapter method so it only runs on tables that have a row-ID column.

    A row-ID column is a single, integer primary key. If the adapter's
    ``_rowid`` is falsy, the wrapped method is not called and
    ``ProgrammingError`` is raised instead.
    """

    @wraps(method)
    def guarded(self: SupersetShillelaghAdapter, *args: Any, **kwargs: Any) -> Any:
        # pylint: disable=protected-access
        if self._rowid:
            return method(self, *args, **kwargs)
        raise ProgrammingError(
            "Can only modify data in a table with a single, integer, primary key"
        )

    return cast(F, guarded)


class Duration(Field[datetime.timedelta, datetime.timedelta]):
    """
    Shillelagh field whose values are ``datetime.timedelta`` durations.

    Its SQL type name is ``DURATION`` and its DB API type is ``DATETIME``.
    """

    db_api_type = "DATETIME"
    type = "DURATION"


class Decimal(Field[decimal.Decimal, decimal.Decimal]):
    """
    Shillelagh field whose values are ``decimal.Decimal`` numbers.

    Its SQL type name is ``DECIMAL`` and its DB API type is ``NUMBER``.
    """

    db_api_type = "NUMBER"
    type = "DECIMAL"


class FallbackField(Field[Any, str]):
    """
    Catch-all field for column types with no dedicated mapping.

    Values are exposed as ``TEXT``/``STRING``: any non-null value is converted
    with ``str()``, and ``None`` is passed through unchanged.
    """

    type = "TEXT"
    db_api_type = "STRING"

    def parse(self, value: Any) -> str | None:
        if value is None:
            return None
        return str(value)


# pylint: disable=too-many-instance-attributes
class SupersetShillelaghAdapter(Adapter):
    """
    A Shillelagh adapter for Superset tables.

    Shillelagh adapters are responsible for fetching data from a given resource,
    allowing it to be represented as a virtual table in SQLite. This one works
    as a proxy to Superset tables: equality/range filters, ordering, limits and
    offsets received from SQLite are translated into a SQLAlchemy query that
    runs on the underlying database. Inserts, updates and deletes are forwarded
    as well, but only when the database has DML enabled. Updates and deletes
    additionally require a single integer primary key.
    """

    # no access to the filesystem
    safe = True

    # LIMIT and OFFSET are pushed down into the generated query.
    supports_limit = True
    supports_offset = True

    # Maps the Python type reported by a column's SQLAlchemy type to the
    # Shillelagh field used to represent it. Unlisted types, including columns
    # whose type reports no Python type, use ``FallbackField``.
    type_map: dict[Any, type[Field]] = {
        bool: Boolean,
        float: Float,
        int: Integer,
        str: String,
        datetime.date: Date,
        datetime.datetime: DateTime,
        datetime.time: Time,
        datetime.timedelta: Duration,
        decimal.Decimal: Decimal,
    }

    @staticmethod
    def supports(
        uri: str,
        fast: bool = True,
        prefix: str | None = "superset",
        allowed_dbs: list[str] | None = None,
        **kwargs: Any,
    ) -> bool:
        """
        Return if a table is supported by the adapter.

        An URL for a table has the format [prefix.]database[[.catalog].schema].table,
        eg, `superset.examples.birth_names`.

        When using the Superset SQLAlchemy and DB engine spec the prefix is dropped, so
        that tables should have the format database[[.catalog].schema].table.

        :param uri: Dot-separated, URL-quoted table identifier.
        :param fast: Not used by this adapter.
        :param prefix: Required leading component, or ``None`` for no prefix.
        :param allowed_dbs: Optional allow-list of database names.
        :return: ``True`` if ``uri`` has a valid shape (and allowed database).
        """
        parts = [urllib.parse.unquote(part) for part in uri.split(".")]

        if prefix is not None:
            if parts.pop(0) != prefix:
                return False

        # Validate the shape before indexing: a URI consisting of only the
        # prefix leaves ``parts`` empty, and ``parts[0]`` would raise.
        if not 2 <= len(parts) <= 4:
            return False

        return allowed_dbs is None or parts[0] in allowed_dbs

    @staticmethod
    def parse_uri(uri: str) -> tuple[str]:
        """
        Pass URI through unmodified.
        """
        return (uri,)

    def __init__(
        self,
        uri: str,
        prefix: str | None = "superset",
        **kwargs: Any,
    ):
        """
        Parse ``uri`` into database/catalog/schema/table and load column info.

        :param uri: Dot-separated, URL-quoted table identifier in the format
            ``[prefix.]database[[.catalog].schema].table``.
        :param prefix: Leading component expected in ``uri``, or ``None``.
        :raises ProgrammingError: if the ``ENABLE_SUPERSET_META_DB`` feature
            flag is off, the prefix does not match, or the database or table
            cannot be found. Access is checked with
            ``security_manager.raise_for_access`` while loading the columns.
        """
        if not feature_flag_manager.is_feature_enabled("ENABLE_SUPERSET_META_DB"):
            raise ProgrammingError("Superset meta database is disabled")

        super().__init__(**kwargs)

        parts = [urllib.parse.unquote(part) for part in uri.split(".")]

        if prefix is not None:
            if prefix != parts.pop(0):
                raise ProgrammingError("Invalid prefix")
            self.prefix = prefix

        # The first component is the database and the last is the table; the
        # optional components in between are taken right-to-left as schema,
        # then catalog.
        self.database = parts.pop(0)
        self.table = parts.pop(-1)
        self.schema = parts.pop(-1) if parts else None
        self.catalog = parts.pop(-1) if parts else None

        # If the table has a single integer primary key we use that as the row ID in order
        # to perform updates and deletes. Otherwise we can only do inserts and selects.
        self._rowid: str | None = None

        # Does the database allow DML?
        self._allow_dml: bool = False

        # Read column information from the database, and store it for later.
        self._set_columns()

    @classmethod
    def get_field(cls, python_type: Any) -> Field:
        """
        Convert a Python type into a Shillelagh field.

        Unknown types (including ``None``) map to ``FallbackField``. Every field
        accepts equality and range filters and any ordering.
        """
        class_ = cls.type_map.get(python_type, FallbackField)
        return class_(filters=[Equal, Range], order=Order.ANY, exact=True)

    @staticmethod
    def _get_python_type(column: Any) -> Any:
        """
        Return the Python type of a SQLAlchemy column, or ``None`` if unknown.

        ``TypeEngine.python_type`` raises ``NotImplementedError`` for types that
        do not declare a Python type. Such columns are mapped to ``None`` so
        they fall back to ``FallbackField`` instead of aborting inspection.
        """
        try:
            return column.type.python_type
        except NotImplementedError:
            return None

    def _set_columns(self) -> None:
        """
        Inspect the table and get its columns.

        This is done on initialization because it's expensive. It also verifies
        access, records whether the database allows DML, and stores a callable
        that yields an engine for the table's catalog/schema.

        :raises ProgrammingError: if the database or the table does not exist.
        """
        # pylint: disable=import-outside-toplevel
        from superset.models.core import Database

        database = (
            db.session.query(Database).filter_by(database_name=self.database).first()
        )
        if database is None:
            raise ProgrammingError(f"Database not found: {self.database}")
        self._allow_dml = database.allow_dml

        # verify permissions
        table = sql_parse.Table(self.table, self.schema, self.catalog)
        security_manager.raise_for_access(database=database, table=table)

        # store this callable for later whenever we need an engine
        self.engine_context = partial(
            database.get_sqla_engine,
            catalog=self.catalog,
            schema=self.schema,
        )

        # fetch column names and types
        metadata = MetaData()
        with self.engine_context() as engine:
            try:
                self._table = Table(
                    self.table,
                    metadata,
                    schema=self.schema,
                    autoload=True,
                    autoload_with=engine,
                )
            except NoSuchTableError as ex:
                raise ProgrammingError(f"Table does not exist: {self.table}") from ex

        # find row ID column; we can only update/delete data into a table with a
        # single integer primary key
        primary_keys = [
            column for column in self._table.primary_key if column.primary_key
        ]
        if (
            len(primary_keys) == 1
            and self._get_python_type(primary_keys[0]) is int
        ):
            self._rowid = primary_keys[0].name

        self.columns = {
            column.name: self.get_field(self._get_python_type(column))
            for column in self._table.c
        }

    def get_columns(self) -> dict[str, Field]:
        """
        Return table columns.
        """
        return self.columns

    def _build_sql(
        self,
        bounds: dict[str, Filter],
        order: list[tuple[str, RequestedOrder]],
        limit: int | None = None,
        offset: int | None = None,
    ) -> Select:
        """
        Build SQLAlchemy query object.

        ``Equal`` filters become ``column == value``. ``Range`` filters become
        ``>=``/``>`` and ``<=``/``<`` depending on ``include_start`` and
        ``include_end``. Ordering, limit and offset are applied as given.

        :raises ProgrammingError: for any other filter type.
        """
        query = select([self._table])

        for column_name, filter_ in bounds.items():
            column = self._table.c[column_name]
            if isinstance(filter_, Equal):
                query = query.where(column == filter_.value)
            elif isinstance(filter_, Range):
                if filter_.start is not None:
                    op = operator.ge if filter_.include_start else operator.gt
                    query = query.where(op(column, filter_.start))
                if filter_.end is not None:
                    op = operator.le if filter_.include_end else operator.lt
                    query = query.where(op(column, filter_.end))
            else:
                raise ProgrammingError(f"Invalid filter: {filter_}")

        for column_name, requested_order in order:
            column = self._table.c[column_name]
            if requested_order == Order.DESCENDING:
                column = column.desc()
            query = query.order_by(column)

        if limit is not None:
            query = query.limit(limit)
        if offset is not None:
            query = query.offset(offset)

        return query

    def get_data(
        self,
        bounds: dict[str, Filter],
        order: list[tuple[str, RequestedOrder]],
        limit: int | None = None,
        offset: int | None = None,
        **kwargs: Any,
    ) -> Iterator[Row]:
        """
        Return data for a `SELECT` statement.

        The effective limit is the smaller of ``limit`` and the
        ``SUPERSET_META_DB_LIMIT`` config value; either may be ``None``. Each
        yielded row has a ``rowid``. That is the primary key value when the table
        has a row-ID column, and otherwise the row's 0-based position in this
        result set.
        """
        app_limit: int | None = current_app.config["SUPERSET_META_DB_LIMIT"]
        if limit is None:
            limit = app_limit
        elif app_limit is not None:
            limit = min(limit, app_limit)

        query = self._build_sql(bounds, order, limit, offset)

        # Close the connection once the rows are consumed (or the generator is
        # closed) instead of leaking it.
        with self.engine_context() as engine, engine.connect() as connection:
            rows = connection.execute(query)
            for i, row in enumerate(rows):
                data = dict(zip(self.columns, row))
                data["rowid"] = data[self._rowid] if self._rowid else i
                yield data

    @check_dml
    def insert_row(self, row: Row) -> int:
        """
        Insert a single row.

        ``row`` contains the column values plus a ``rowid`` entry. When the table
        has a row-ID column and a rowid is given, it must equal that column's
        value. A ``None`` primary key on an autoincrement column is dropped so
        that the database generates it.

        :return: The inserted primary key when the table has a row-ID column;
            otherwise the table's row count after the insert.
        :raises ProgrammingError: if DML is disabled or the rowid does not match.
        """
        row_id: int | None = row.pop("rowid")
        if row_id and self._rowid:
            if row.get(self._rowid) != row_id:
                raise ProgrammingError(f"Invalid rowid specified: {row_id}")
            row[self._rowid] = row_id

        if (
            self._rowid
            and row[self._rowid] is None
            and self._table.c[self._rowid].autoincrement
        ):
            row.pop(self._rowid)

        query = self._table.insert().values(**row)

        with self.engine_context() as engine, engine.connect() as connection:
            result = connection.execute(query)

            # return rowid
            if self._rowid:
                return result.inserted_primary_key[0]

            # Without a row-ID column, report the current row count instead.
            query = select([func.count()]).select_from(self._table)
            return connection.execute(query).scalar()

    @check_dml
    @has_rowid
    def delete_row(self, row_id: int) -> None:
        """
        Delete a single row given its row ID.

        :raises ProgrammingError: if DML is disabled or the table has no
            row-ID column.
        """
        query = self._table.delete().where(self._table.c[self._rowid] == row_id)

        with self.engine_context() as engine, engine.connect() as connection:
            connection.execute(query)

    @check_dml
    @has_rowid
    def update_row(self, row_id: int, row: Row) -> None:
        """
        Update a single row given its row ID.

        Note that the updated row might have a new row ID. If ``row`` carries a
        new ``rowid``, it must equal the new primary key value in ``row``.

        :raises ProgrammingError: if DML is disabled, the table has no row-ID
            column, or the new rowid does not match.
        """
        new_row_id: int | None = row.pop("rowid")
        if new_row_id:
            if row.get(self._rowid) != new_row_id:
                raise ProgrammingError(f"Invalid rowid specified: {new_row_id}")
            row[self._rowid] = new_row_id

        query = (
            self._table.update()
            .where(self._table.c[self._rowid] == row_id)
            .values(**row)
        )

        with self.engine_context() as engine, engine.connect() as connection:
            connection.execute(query)