airbnb/superset

View on GitHub
superset/connectors/sqla/models.py

Summary

Maintainability
F
1 wk
Test Coverage
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
from __future__ import annotations

import builtins
import dataclasses
import json
import logging
import re
from collections import defaultdict
from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from json.decoder import JSONDecodeError
from typing import Any, Callable, cast

import dateutil.parser
import numpy as np
import pandas as pd
import sqlalchemy as sa
from flask_appbuilder import Model
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __, lazy_gettext as _
from jinja2.exceptions import TemplateError
from markupsafe import escape, Markup
from sqlalchemy import (
    and_,
    Boolean,
    Column,
    DateTime,
    Enum,
    ForeignKey,
    inspect,
    Integer,
    or_,
    String,
    Table as DBTable,
    Text,
    update,
)
from sqlalchemy.engine.base import Connection
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import (
    backref,
    foreign,
    Mapped,
    Query,
    reconstructor,
    relationship,
    RelationshipProperty,
)
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table
from sqlalchemy.sql.elements import ColumnClause, TextClause
from sqlalchemy.sql.expression import Label, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause

from superset import app, db, is_feature_enabled, security_manager
from superset.commands.dataset.exceptions import DatasetNotFoundError
from superset.common.db_query_status import QueryStatus
from superset.connectors.sqla.utils import (
    get_columns_description,
    get_physical_table_metadata,
    get_virtual_table_metadata,
)
from superset.constants import EMPTY_STRING, NULL_STRING
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.exceptions import (
    ColumnNotFoundException,
    DatasetInvalidPermissionEvaluationException,
    QueryClauseValidationException,
    QueryObjectValidationError,
    SupersetErrorException,
    SupersetErrorsException,
    SupersetGenericDBErrorException,
    SupersetSecurityException,
)
from superset.jinja_context import (
    BaseTemplateProcessor,
    ExtraCache,
    get_template_processor,
)
from superset.models.annotations import Annotation
from superset.models.core import Database
from superset.models.helpers import (
    AuditMixinNullable,
    CertificationMixin,
    ExploreMixin,
    ImportExportMixin,
    QueryResult,
    validate_adhoc_subquery,
)
from superset.models.slice import Slice
from superset.sql_parse import ParsedQuery, sanitize_clause, Table
from superset.superset_typing import (
    AdhocColumn,
    AdhocMetric,
    FilterValue,
    FilterValues,
    Metric,
    QueryObjectDict,
    ResultSetColumnType,
)
from superset.utils import core as utils
from superset.utils.backports import StrEnum
from superset.utils.core import GenericDataType, MediumText

config = app.config
metadata = Model.metadata  # pylint: disable=no-member
logger = logging.getLogger(__name__)
ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"]
VIRTUAL_TABLE_ALIAS = "virtual_table"

# A non-exhaustive set of additive metrics. The ``_LOWER`` companion holds the
# same names lower-cased (presumably for case-insensitive lookups by callers).
ADDITIVE_METRIC_TYPES = {"count", "sum", "doubleSum"}
ADDITIVE_METRIC_TYPES_LOWER = set(map(str.lower, ADDITIVE_METRIC_TYPES))


@dataclasses.dataclass
class MetadataResult:
    """Names reported as added, removed or modified by a metadata comparison.

    Each list starts empty and is independent per instance.
    """

    added: list[str] = dataclasses.field(default_factory=list)
    removed: list[str] = dataclasses.field(default_factory=list)
    modified: list[str] = dataclasses.field(default_factory=list)


# Form-data keys whose values reference metrics. ``data_for_slices`` scans
# them to decide which metrics a set of charts actually needs.
METRIC_FORM_DATA_PARAMS = [
    "metric",
    "metric_2",
    "metrics",
    "metrics_b",
    "percent_metrics",
    "secondary_metric",
    "size",
    "timeseries_limit_metric",
    "x",
    "y",
]

# Form-data keys whose values reference columns. ``data_for_slices`` only
# consults them for charts that have no query context (legacy charts).
COLUMN_FORM_DATA_PARAMS = [
    "all_columns",
    "all_columns_x",
    "columns",
    "entity",
    "groupby",
    "order_by_cols",
    "series",
]


class DatasourceKind(StrEnum):
    """How a datasource is defined: by a SQL query or by a physical table."""

    # Datasource has ``sql`` set (see ``BaseDatasource.kind``)
    VIRTUAL = "virtual"
    # Datasource has no ``sql``; it maps directly onto a table
    PHYSICAL = "physical"


class BaseDatasource(AuditMixinNullable, ImportExportMixin):  # pylint: disable=too-many-public-methods
    """A common interface to objects that are queryable
    (tables and datasources).

    Subclasses are expected to provide ``name``, ``type``, ``datasource_name``,
    the query methods that raise ``NotImplementedError`` here, ``columns`` /
    ``metrics`` collections, ``update_from_object_fields`` and a ``database``
    attribute (read by ``data``).
    """

    # ---------------------------------------------------------------
    # class attributes to define when deriving BaseDatasource
    # ---------------------------------------------------------------
    __tablename__: str | None = None  # {connector_name}_datasource
    baselink: str | None = None  # url portion pointing to ModelView endpoint

    owner_class: User | None = None

    # Used to do code highlighting when displaying the query in the UI
    query_language: str | None = None

    # Only some datasources support Row Level Security
    is_rls_supported: bool = False

    @property
    def name(self) -> str:
        # can be a Column or a property pointing to one
        raise NotImplementedError()

    # ---------------------------------------------------------------

    # Columns
    id = Column(Integer, primary_key=True)
    description = Column(Text)
    default_endpoint = Column(Text)
    is_featured = Column(Boolean, default=False)  # TODO deprecating
    filter_select_enabled = Column(Boolean, default=True)
    offset = Column(Integer, default=0)
    cache_timeout = Column(Integer)
    params = Column(String(1000))
    perm = Column(String(1000))
    schema_perm = Column(String(1000))
    is_managed_externally = Column(Boolean, nullable=False, default=False)
    external_url = Column(Text, nullable=True)

    # A truthy ``sql`` makes the datasource virtual (see ``kind``)
    sql: str | None = None
    owners: list[User]
    update_from_object_fields: list[str]

    extra_import_fields = ["is_managed_externally", "external_url"]

    @property
    def kind(self) -> DatasourceKind:
        """``VIRTUAL`` when the datasource is defined by SQL, else ``PHYSICAL``."""
        return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL

    @property
    def owners_data(self) -> list[dict[str, Any]]:
        """Name, username and id of every owner."""
        return [
            {
                "first_name": o.first_name,
                "last_name": o.last_name,
                "username": o.username,
                "id": o.id,
            }
            for o in self.owners
        ]

    @property
    def is_virtual(self) -> bool:
        """Whether the datasource is defined by a SQL query."""
        return self.kind == DatasourceKind.VIRTUAL

    @declared_attr
    def slices(self) -> RelationshipProperty:
        # ``declared_attr`` is invoked with the mapped class, so ``self`` here is
        # the class and ``self.id`` / ``self.type`` are class-level attributes.
        return relationship(
            "Slice",
            overlaps="table",
            primaryjoin=lambda: and_(
                foreign(Slice.datasource_id) == self.id,
                foreign(Slice.datasource_type) == self.type,
            ),
        )

    columns: list[TableColumn] = []
    metrics: list[SqlMetric] = []

    @property
    def type(self) -> str:
        raise NotImplementedError()

    @property
    def uid(self) -> str:
        """Unique id across datasource types"""
        return f"{self.id}__{self.type}"

    @property
    def column_names(self) -> list[str]:
        """Column names sorted alphabetically; a missing name sorts as ``""``."""
        return sorted([c.column_name for c in self.columns], key=lambda x: x or "")

    @property
    def columns_types(self) -> dict[str, str]:
        """Mapping of column name to its native type string."""
        return {c.column_name: c.type for c in self.columns}

    @property
    def main_dttm_col(self) -> str:
        return "timestamp"

    @property
    def datasource_name(self) -> str:
        raise NotImplementedError()

    @property
    def connection(self) -> str | None:
        """String representing the context of the Datasource"""
        return None

    @property
    def schema(self) -> str | None:
        """String representing the schema of the Datasource (if it applies)"""
        return None

    @property
    def filterable_column_names(self) -> list[str]:
        """Sorted names of the columns flagged as ``filterable``."""
        return sorted([c.column_name for c in self.columns if c.filterable])

    @property
    def dttm_cols(self) -> list[str]:
        return []

    @property
    def url(self) -> str:
        """Edit URL of the datasource in its ModelView."""
        return f"/{self.baselink}/edit/{self.id}"

    @property
    def explore_url(self) -> str:
        """``default_endpoint`` when set, otherwise the generic Explore URL."""
        if self.default_endpoint:
            return self.default_endpoint
        return f"/explore/?datasource_type={self.type}&datasource_id={self.id}"

    @property
    def column_formats(self) -> dict[str, str | None]:
        """d3 format string per metric name, for metrics that define one."""
        return {m.metric_name: m.d3format for m in self.metrics if m.d3format}

    @property
    def currency_formats(self) -> dict[str, dict[str, str | None] | None]:
        """Currency settings per metric name, for metrics that define them."""
        return {m.metric_name: m.currency_json for m in self.metrics if m.currency_json}

    def add_missing_metrics(self, metrics: list[SqlMetric]) -> None:
        """Attach each metric whose name is not already present on this datasource."""
        existing_metrics = {m.metric_name for m in self.metrics}
        for metric in metrics:
            if metric.metric_name not in existing_metrics:
                metric.table_id = self.id
                self.metrics.append(metric)

    @property
    def short_data(self) -> dict[str, Any]:
        """Data representation of the datasource sent to the frontend"""
        return {
            "edit_url": self.url,
            "id": self.id,
            "uid": self.uid,
            "schema": self.schema or None,
            "name": self.name,
            "type": self.type,
            "connection": self.connection,
            "creator": str(self.created_by),
        }

    @property
    def select_star(self) -> str | None:
        """A ``SELECT *`` statement for the datasource; none by default."""
        return None

    @property
    def order_by_choices(self) -> list[tuple[str, str]]:
        """(JSON ``[column, is_ascending]``, label) pairs, asc then desc per column."""
        choices = []
        # self.column_names return sorted column_names
        for column_name in self.column_names:
            column_name = str(column_name or "")
            choices.append(
                (json.dumps([column_name, True]), f"{column_name} " + __("[asc]"))
            )
            choices.append(
                (json.dumps([column_name, False]), f"{column_name} " + __("[desc]"))
            )
        return choices

    @property
    def verbose_map(self) -> dict[str, str]:
        """Display label per metric/column name, falling back to the raw name.

        Columns are applied after metrics and therefore win on a name clash.
        """
        verb_map = {"__timestamp": "Time"}
        verb_map.update(
            {o.metric_name: o.verbose_name or o.metric_name for o in self.metrics}
        )
        verb_map.update(
            {o.column_name: o.verbose_name or o.column_name for o in self.columns}
        )
        return verb_map

    @property
    def data(self) -> dict[str, Any]:
        """Data representation of the datasource sent to the frontend"""
        return {
            # simple fields
            "id": self.id,
            "uid": self.uid,
            "column_formats": self.column_formats,
            "currency_formats": self.currency_formats,
            "description": self.description,
            "database": self.database.data,  # pylint: disable=no-member
            "default_endpoint": self.default_endpoint,
            "filter_select": self.filter_select_enabled,  # TODO deprecate
            "filter_select_enabled": self.filter_select_enabled,
            "name": self.name,
            "datasource_name": self.datasource_name,
            "table_name": self.datasource_name,
            "type": self.type,
            "schema": self.schema or None,
            "offset": self.offset,
            "cache_timeout": self.cache_timeout,
            "params": self.params,
            "perm": self.perm,
            "edit_url": self.url,
            # sqla-specific
            "sql": self.sql,
            # one to many
            "columns": [o.data for o in self.columns],
            "metrics": [o.data for o in self.metrics],
            # TODO deprecate, move logic to JS
            "order_by_choices": self.order_by_choices,
            "owners": [owner.id for owner in self.owners],
            "verbose_map": self.verbose_map,
            "select_star": self.select_star,
        }

    def data_for_slices(  # pylint: disable=too-many-locals
        self, slices: list[Slice]
    ) -> dict[str, Any]:
        """
        The representation of the datasource containing only the required data
        to render the provided slices.

        Used to reduce the payload when loading a dashboard. Metrics and columns
        are kept only when a slice references them (through its form data or its
        query context). ``column_types`` still lists the generic types of *all*
        columns, and ``description`` is removed.

        :param slices: charts that will be rendered from this datasource
        :return: a pruned version of ``self.data``
        """
        data = self.data
        metric_names = set()
        column_names = set()
        for slc in slices:
            form_data = slc.form_data
            # pull out all required metrics from the form_data
            for metric_param in METRIC_FORM_DATA_PARAMS:
                for metric in utils.as_list(form_data.get(metric_param) or []):
                    metric_names.add(utils.get_metric_name(metric))
                    if utils.is_adhoc_metric(metric):
                        column_ = metric.get("column") or {}
                        if column_name := column_.get("column_name"):
                            column_names.add(column_name)

            # Columns used in query filters
            column_names.update(
                filter_["subject"]
                for filter_ in form_data.get("adhoc_filters") or []
                if filter_.get("clause") == "WHERE" and filter_.get("subject")
            )

            # columns used by Filter Box
            column_names.update(
                filter_config["column"]
                for filter_config in form_data.get("filter_configs") or []
                if "column" in filter_config
            )

            # for legacy dashboard imports which have the wrong query_context in them
            try:
                query_context = slc.get_query_context()
            except DatasetNotFoundError:
                query_context = None

            # legacy charts don't have query_context charts
            if query_context:
                column_names.update(
                    utils.get_column_name(column_)
                    for query in query_context.queries
                    for column_ in query.columns
                )
            else:
                column_names.update(
                    utils.get_column_name(column_)
                    if utils.is_adhoc_column(column_)
                    else column_
                    for column_param in COLUMN_FORM_DATA_PARAMS
                    for column_ in utils.as_list(form_data.get(column_param) or [])
                )

        filtered_metrics = [
            metric
            for metric in data["metrics"]
            if metric["metric_name"] in metric_names
        ]

        # Entries are the serialized ``TableColumn.data`` dicts, not ORM columns
        filtered_columns: list[dict[str, Any]] = []
        column_types: set[GenericDataType] = set()
        for column_ in data["columns"]:
            generic_type = column_.get("type_generic")
            if generic_type is not None:
                column_types.add(generic_type)
            if column_["column_name"] in column_names:
                filtered_columns.append(column_)

        data["column_types"] = list(column_types)
        del data["description"]
        data["metrics"] = filtered_metrics
        data["columns"] = filtered_columns
        verbose_map = {"__timestamp": "Time"}
        verbose_map.update(
            {
                metric["metric_name"]: metric["verbose_name"] or metric["metric_name"]
                for metric in filtered_metrics
            }
        )
        verbose_map.update(
            {
                column_["column_name"]: column_["verbose_name"]
                or column_["column_name"]
                for column_ in filtered_columns
            }
        )
        data["verbose_map"] = verbose_map

        return data

    @staticmethod
    def filter_values_handler(  # pylint: disable=too-many-arguments
        values: FilterValues | None,
        operator: str,
        target_generic_type: GenericDataType,
        target_native_type: str | None = None,
        is_list_target: bool = False,
        db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
        db_extra: dict[str, Any] | None = None,
    ) -> FilterValues | None:
        """Coerce raw filter values for comparison against the target column.

        Per value:

        * with the temporal-range operator, the value is returned unchanged;
        * int/float values against a temporal column (when native type and
          engine spec are known) are treated as epoch milliseconds and
          converted into an engine-specific literal;
        * strings are stripped of tabs/newlines, then cast to a number for
          numeric targets (except with LIKE/ILIKE), and the NULL/empty
          sentinels are mapped to ``None`` / ``""``;
        * boolean targets are cast with ``utils.cast_to_boolean``.

        The result is wrapped in a list when ``is_list_target`` is set, or
        reduced to its first element (``None`` if empty) otherwise.
        """
        if values is None:
            return None

        def handle_single_value(value: FilterValue | None) -> FilterValue | None:
            if operator == utils.FilterOperator.TEMPORAL_RANGE:
                return value
            if (
                isinstance(value, (float, int))
                and target_generic_type == utils.GenericDataType.TEMPORAL
                and target_native_type is not None
                and db_engine_spec is not None
            ):
                value = db_engine_spec.convert_dttm(
                    target_type=target_native_type,
                    dttm=datetime.utcfromtimestamp(value / 1000),
                    db_extra=db_extra,
                )
                value = literal_column(value)
            if isinstance(value, str):
                value = value.strip("\t\n")

                if (
                    target_generic_type == utils.GenericDataType.NUMERIC
                    and operator
                    not in {
                        utils.FilterOperator.ILIKE,
                        utils.FilterOperator.LIKE,
                    }
                ):
                    # For backwards compatibility and edge cases
                    # where a column data type might have changed
                    return utils.cast_to_num(value)
                if value == NULL_STRING:
                    return None
                if value == EMPTY_STRING:
                    return ""
            if target_generic_type == utils.GenericDataType.BOOLEAN:
                return utils.cast_to_boolean(value)
            return value

        if isinstance(values, (list, tuple)):
            values = [handle_single_value(v) for v in values]  # type: ignore
        else:
            values = handle_single_value(values)
        if is_list_target and not isinstance(values, (tuple, list)):
            values = [values]  # type: ignore
        elif not is_list_target and isinstance(values, (tuple, list)):
            values = values[0] if values else None
        return values

    def external_metadata(self) -> list[ResultSetColumnType]:
        """Returns column information from the external system"""
        raise NotImplementedError()

    def get_query_str(self, query_obj: QueryObjectDict) -> str:
        """Returns a query as a string

        This is used to be displayed to the user so that they can
        understand what is taking place behind the scene"""
        raise NotImplementedError()

    def query(self, query_obj: QueryObjectDict) -> QueryResult:
        """Executes the query and returns a dataframe

        query_obj is a dictionary representing Superset's query interface.
        Should return a ``superset.models.helpers.QueryResult``
        """
        raise NotImplementedError()

    @staticmethod
    def default_query(qry: Query) -> Query:
        """Hook for subclasses to restrict a query; the base returns it as is."""
        return qry

    def get_column(self, column_name: str | None) -> TableColumn | None:
        """Return the column with the given name, or ``None`` if absent/empty."""
        if not column_name:
            return None
        for col in self.columns:
            if col.column_name == column_name:
                return col
        return None

    @staticmethod
    def get_fk_many_from_list(
        object_list: list[Any],
        fkmany: list[Column],
        fkmany_class: builtins.type[TableColumn | SqlMetric],
        key_attr: str,
    ) -> list[Column]:
        """Update ORM one-to-many list from object list

        Used for syncing metrics and columns using the same code.

        :param object_list: dicts describing the desired children; for children
            that do not exist yet, any ``"id"`` key is removed in place
        :param fkmany: the current ORM children
        :param fkmany_class: ORM class used to build the new children
        :param key_attr: attribute / dict key matching dicts to ORM objects
        :return: surviving (and updated) ORM children followed by the new ones
        """
        object_dict = {o.get(key_attr): o for o in object_list}

        # delete fks that have been removed
        fkmany = [o for o in fkmany if getattr(o, key_attr) in object_dict]

        # sync existing fks
        for fk in fkmany:
            obj = object_dict.get(getattr(fk, key_attr))
            if obj:
                for attr in fkmany_class.update_from_object_fields:
                    setattr(fk, attr, obj.get(attr))

        # create new fks; a set keeps the membership test O(1) per object
        orm_keys = {getattr(o, key_attr) for o in fkmany}
        new_fks = []
        for obj in object_list:
            if obj.get(key_attr) in orm_keys:
                continue
            # Discard any client-side id so it is not reused for the new row;
            # tolerate payloads that omit it rather than raising KeyError.
            obj.pop("id", None)
            orm_kwargs = {
                k: v
                for k, v in obj.items()
                if k in fkmany_class.update_from_object_fields
            }
            new_fks.append(fkmany_class(**orm_kwargs))
        fkmany += new_fks
        return fkmany

    def update_from_object(self, obj: dict[str, Any]) -> None:
        """Update datasource from a data structure

        The UI's table editor crafts a complex data structure that
        contains most of the datasource's properties as well as
        an array of metrics and columns objects. This method
        receives the object from the UI and syncs the datasource to
        match it. Since the fields are different for the different
        connectors, the implementation uses ``update_from_object_fields``
        which can be defined for each connector and
        defines which fields should be synced.

        Note that a missing ``metrics`` / ``columns`` key clears that list.
        """
        for attr in self.update_from_object_fields:
            setattr(self, attr, obj.get(attr))

        self.owners = obj.get("owners", [])

        # Syncing metrics
        metrics = (
            self.get_fk_many_from_list(
                obj["metrics"], self.metrics, SqlMetric, "metric_name"
            )
            if "metrics" in obj
            else []
        )
        self.metrics = metrics

        # Syncing columns
        self.columns = (
            self.get_fk_many_from_list(
                obj["columns"], self.columns, TableColumn, "column_name"
            )
            if "columns" in obj
            else []
        )

    def get_extra_cache_keys(
        self,
        query_obj: QueryObjectDict,  # pylint: disable=unused-argument
    ) -> list[Hashable]:
        """If a datasource needs to provide additional keys for calculation of
        cache keys, those can be provided via this method

        :param query_obj: The dict representation of a query object
        :return: list of keys
        """
        return []

    def __hash__(self) -> int:
        return hash(self.uid)

    def __eq__(self, other: object) -> bool:
        # Equality is by ``uid`` so that it stays consistent with ``__hash__``
        if not isinstance(other, BaseDatasource):
            return NotImplemented
        return self.uid == other.uid

    def raise_for_access(self) -> None:
        """
        Raise an exception if the user cannot access the resource.

        :raises SupersetSecurityException: If the user cannot access the resource
        """

        security_manager.raise_for_access(datasource=self)

    @classmethod
    def get_datasource_by_name(
        cls, datasource_name: str, schema: str, database_name: str
    ) -> BaseDatasource | None:
        raise NotImplementedError()


class AnnotationDatasource(BaseDatasource):
    """Dummy object so we can query annotations using 'Viz' objects just like
    regular datasources.
    """

    cache_timeout = 0
    changed_on = None
    type = "annotation"
    column_names = [
        "created_on",
        "changed_on",
        "id",
        "start_dttm",
        "end_dttm",
        "layer_id",
        "short_descr",
        "long_descr",
        "json_metadata",
        "created_by_fk",
        "changed_by_fk",
    ]

    def query(self, query_obj: QueryObjectDict) -> QueryResult:
        """Load the annotations of one layer into a DataFrame.

        The layer id is the ``val`` of the first entry in ``query_obj["filter"]``.
        When ``from_dttm`` / ``to_dttm`` are set, annotations must start at or
        after the former and end at or before the latter. Any failure is logged
        and reported as a FAILED result with an empty DataFrame.
        """
        error_message = None
        annotations = db.session.query(Annotation)
        annotations = annotations.filter(
            Annotation.layer_id == query_obj["filter"][0]["val"]
        )
        if lower_bound := query_obj["from_dttm"]:
            annotations = annotations.filter(Annotation.start_dttm >= lower_bound)
        if upper_bound := query_obj["to_dttm"]:
            annotations = annotations.filter(Annotation.end_dttm <= upper_bound)

        status = QueryStatus.SUCCESS
        try:
            frame = pd.read_sql_query(annotations.statement, db.engine)
        except Exception as ex:  # pylint: disable=broad-except
            frame = pd.DataFrame()
            status = QueryStatus.FAILED
            logger.exception(ex)
            error_message = utils.error_msg_from_exception(ex)

        return QueryResult(
            status=status,
            df=frame,
            duration=timedelta(0),
            query="",
            error_message=error_message,
        )

    def get_query_str(self, query_obj: QueryObjectDict) -> str:
        raise NotImplementedError()

    def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
        raise NotImplementedError()


class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Model):
    """ORM object for table columns, each table can have multiple columns"""

    __tablename__ = "table_columns"
    __table_args__ = (UniqueConstraint("table_id", "column_name"),)

    id = Column(Integer, primary_key=True)
    column_name = Column(String(255), nullable=False)
    verbose_name = Column(String(1024))
    is_active = Column(Boolean, default=True)
    # native database type string, resolved through the engine spec
    type = Column(Text)
    advanced_data_type = Column(String(255))
    groupby = Column(Boolean, default=True)
    filterable = Column(Boolean, default=True)
    description = Column(MediumText())
    table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
    is_dttm = Column(Boolean, default=False)
    # when set, this SQL is rendered instead of referencing ``column_name``
    expression = Column(MediumText())
    python_date_format = Column(String(255))
    extra = Column(Text)

    table: Mapped[SqlaTable] = relationship(
        "SqlaTable",
        back_populates="columns",
    )

    export_fields = [
        "table_id",
        "column_name",
        "verbose_name",
        "is_dttm",
        "is_active",
        "type",
        "advanced_data_type",
        "groupby",
        "filterable",
        "expression",
        "description",
        "python_date_format",
        "extra",
    ]

    update_from_object_fields = [s for s in export_fields if s not in ("table_id",)]
    export_parent = "table"

    def __init__(self, **kwargs: Any) -> None:
        """
        Construct a TableColumn object.

        Historically a TableColumn object (from an ORM perspective) was tightly bound to
        a SqlaTable object, however with the introduction of the Query datasource this
        is no longer true, i.e., the SqlaTable relationship is optional.

        Now the TableColumn is either directly associated with the Database object (
        which is unknown to the ORM) or indirectly via the SqlaTable object (courtesy of
        the ORM) depending on the context.
        """

        self._database: Database | None = kwargs.pop("database", None)
        super().__init__(**kwargs)

    @reconstructor
    def init_on_load(self) -> None:
        """
        Construct a TableColumn object when invoked via the SQLAlchemy ORM.
        """

        self._database = None

    def __repr__(self) -> str:
        return str(self.column_name)

    @property
    def is_boolean(self) -> bool:
        """
        Check if the column has a boolean datatype.
        """
        return self.type_generic == GenericDataType.BOOLEAN

    @property
    def is_numeric(self) -> bool:
        """
        Check if the column has a numeric datatype.
        """
        return self.type_generic == GenericDataType.NUMERIC

    @property
    def is_string(self) -> bool:
        """
        Check if the column has a string datatype.
        """
        return self.type_generic == GenericDataType.STRING

    @property
    def is_temporal(self) -> bool:
        """
        Check if the column has a temporal datatype. If column has been set as
        temporal/non-temporal (`is_dttm` is True or False respectively), return that
        value. This usually happens during initial metadata fetching or when a column
        is manually set as temporal (for this `python_date_format` needs to be set).
        """
        if self.is_dttm is not None:
            return self.is_dttm
        return self.type_generic == GenericDataType.TEMPORAL

    @property
    def database(self) -> Database:
        """The owning table's database, else the one given at construction.

        May be ``None`` for a detached column built without ``database=``.
        """
        return self.table.database if self.table else self._database

    @property
    def db_engine_spec(self) -> builtins.type[BaseEngineSpec]:
        return self.database.db_engine_spec

    @property
    def db_extra(self) -> dict[str, Any]:
        return self.database.get_extra()

    @property
    def type_generic(self) -> utils.GenericDataType | None:
        """``TEMPORAL`` when ``is_dttm`` is set; otherwise the engine spec's
        generic type for ``type``, or ``None`` if the engine has no spec for it."""
        if self.is_dttm:
            return GenericDataType.TEMPORAL

        return (
            column_spec.generic_type
            if (
                column_spec := self.db_engine_spec.get_column_spec(
                    self.type,
                    db_extra=self.db_extra,
                )
            )
            else None
        )

    def _get_base_sqla_column(
        self,
        type_: Any,
        template_processor: BaseTemplateProcessor | None = None,
    ) -> ColumnClause:
        """
        Build the unlabelled SQLAlchemy element shared by the query helpers.

        A calculated column (non-empty ``expression``) is emitted verbatim after
        optional Jinja processing; otherwise the column is referenced by name.

        :param type_: SQLAlchemy type to attach to the element (may be ``None``)
        :param template_processor: processor applied to ``expression`` if given
        """
        if expression := self.expression:
            if template_processor:
                expression = template_processor.process_template(expression)
            return literal_column(expression, type_=type_)
        return column(self.column_name, type_=type_)

    def get_sqla_col(
        self,
        label: str | None = None,
        template_processor: BaseTemplateProcessor | None = None,
    ) -> Column:
        """
        Return this column as a labelled SQLAlchemy element.

        :param label: alias for the column; defaults to ``column_name``
        :param template_processor: processor applied to ``expression`` if given
        :return: the element made label-compatible by the database
        """
        label = label or self.column_name
        db_engine_spec = self.db_engine_spec
        column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra)
        type_ = column_spec.sqla_type if column_spec else None
        col = self._get_base_sqla_column(type_, template_processor)
        return self.database.make_sqla_column_compatible(col, label)

    @property
    def datasource(self) -> RelationshipProperty:
        return self.table

    def get_timestamp_expression(
        self,
        time_grain: str | None,
        label: str | None = None,
        template_processor: BaseTemplateProcessor | None = None,
    ) -> TimestampExpression | Label:
        """
        Return a SQLAlchemy Core element representation of self to be used in a query.

        :param time_grain: Optional time grain, e.g. P1Y
        :param label: alias/label that column is expected to have
        :param template_processor: template processor
        :return: A TimeExpression object wrapped in a Label if supported by db
        """
        label = label or utils.DTTM_ALIAS

        pdf = self.python_date_format
        is_epoch = pdf in ("epoch_s", "epoch_ms")
        column_spec = self.db_engine_spec.get_column_spec(
            self.type, db_extra=self.db_extra
        )
        # Fall back to DateTime so the element is typed as temporal
        type_ = column_spec.sqla_type if column_spec else DateTime
        col = self._get_base_sqla_column(type_, template_processor)
        if not self.expression and not time_grain and not is_epoch:
            # Plain physical column needing no grain/epoch conversion
            return self.database.make_sqla_column_compatible(col, label)
        time_expr = self.db_engine_spec.get_timestamp_expr(col, pdf, time_grain)
        return self.database.make_sqla_column_compatible(time_expr, label)

    @property
    def data(self) -> dict[str, Any]:
        """Serializable subset of the column's attributes sent to the frontend."""
        attrs = (
            "advanced_data_type",
            "certification_details",
            "certified_by",
            "column_name",
            "description",
            "expression",
            "filterable",
            "groupby",
            "id",
            "is_certified",
            "is_dttm",
            "python_date_format",
            "type",
            "type_generic",
            "verbose_name",
            "warning_markdown",
        )

        return {s: getattr(self, s) for s in attrs if hasattr(self, s)}


class SqlMetric(AuditMixinNullable, ImportExportMixin, CertificationMixin, Model):
    """
    ORM object for a saved metric.

    A metric is a named SQL expression attached to a ``SqlaTable``. A table can
    own many metrics, but a metric name may only appear once per table.
    """

    __tablename__ = "sql_metrics"
    __table_args__ = (UniqueConstraint("table_id", "metric_name"),)

    id = Column(Integer, primary_key=True)
    metric_name = Column(String(255), nullable=False)
    verbose_name = Column(String(1024))
    metric_type = Column(String(32))
    description = Column(MediumText())
    d3format = Column(String(128))
    currency = Column(String(128))
    warning_text = Column(Text)
    table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
    expression = Column(MediumText(), nullable=False)
    extra = Column(Text)

    table: Mapped[SqlaTable] = relationship(
        "SqlaTable",
        back_populates="metrics",
    )

    # Fields serialized on export and accepted on import.
    export_fields = [
        "metric_name",
        "verbose_name",
        "metric_type",
        "table_id",
        "expression",
        "description",
        "d3format",
        "currency",
        "extra",
        "warning_text",
    ]
    # Same as ``export_fields`` minus the parent foreign key.
    update_from_object_fields = [f for f in export_fields if f != "table_id"]
    export_parent = "table"

    def __repr__(self) -> str:
        return str(self.metric_name)

    def get_sqla_col(
        self,
        label: str | None = None,
        template_processor: BaseTemplateProcessor | None = None,
    ) -> Column:
        """
        Build the SQLAlchemy element for this metric's expression.

        :param label: alias for the element; falls back to ``metric_name``
        :param template_processor: optional processor applied to the SQL first
        :return: the expression made label-compatible with the table's database
        """
        alias = label or self.metric_name
        sql = (
            template_processor.process_template(self.expression)
            if template_processor
            else self.expression
        )
        sqla_col: ColumnClause = literal_column(sql)
        return self.table.database.make_sqla_column_compatible(sqla_col, alias)

    @property
    def perm(self) -> str | None:
        """Permission name ``<table full name>.[<metric>](id:<id>)``, if attached."""
        if not self.table:
            return None
        return f"{self.table.full_name}.[{self.metric_name}](id:{self.id})"

    def get_perm(self) -> str | None:
        return self.perm

    @property
    def currency_json(self) -> dict[str, str | None] | None:
        """
        The ``currency`` column parsed as JSON.

        Empty or missing values yield ``None``; unparsable values are logged and
        also yield ``None``.
        """
        raw_currency = self.currency or "{}"
        try:
            parsed = json.loads(raw_currency)
        except (TypeError, JSONDecodeError) as exc:
            logger.error(
                "Unable to load currency json: %r. Leaving empty.", exc, exc_info=True
            )
            return None
        return parsed or None

    @property
    def data(self) -> dict[str, Any]:
        """Serialize the metric's public attributes into a plain dictionary."""
        serialized: dict[str, Any] = {}
        for attr_name in (
            "certification_details",
            "certified_by",
            "currency",
            "d3format",
            "description",
            "expression",
            "id",
            "is_certified",
            "metric_name",
            "warning_markdown",
            "warning_text",
            "verbose_name",
        ):
            serialized[attr_name] = getattr(self, attr_name)
        return serialized


# Association table linking datasets ("tables") to their owners ("ab_user"); it
# is used as the ``secondary`` table of ``SqlaTable.owners``. Both foreign keys are
# declared ON DELETE CASCADE, so link rows are removed together with either side
# (where the database enforces foreign keys).
sqlatable_user = DBTable(
    "sqlatable_user",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("user_id", Integer, ForeignKey("ab_user.id", ondelete="CASCADE")),
    Column("table_id", Integer, ForeignKey("tables.id", ondelete="CASCADE")),
)


def _process_sql_expression(
    expression: str | None,
    database_id: int,
    schema: str,
    template_processor: BaseTemplateProcessor | None = None,
) -> str | None:
    if template_processor and expression:
        expression = template_processor.process_template(expression)
    if expression:
        try:
            expression = validate_adhoc_subquery(
                expression,
                database_id,
                schema,
            )
            expression = sanitize_clause(expression)
        except (QueryClauseValidationException, SupersetSecurityException) as ex:
            raise QueryObjectValidationError(ex.message) from ex
    return expression


class SqlaTable(
    Model,
    BaseDatasource,
    ExploreMixin,
):  # pylint: disable=too-many-public-methods
    """An ORM object for SqlAlchemy table references"""

    type = "table"
    query_language = "sql"
    is_rls_supported = True
    columns: Mapped[list[TableColumn]] = relationship(
        TableColumn,
        back_populates="table",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )
    metrics: Mapped[list[SqlMetric]] = relationship(
        SqlMetric,
        back_populates="table",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )
    metric_class = SqlMetric
    column_class = TableColumn
    owner_class = security_manager.user_model

    __tablename__ = "tables"

    # Note this uniqueness constraint is not part of the physical schema, i.e., it does
    # not exist in the migrations, but is required by `import_from_dict` to ensure the
    # correct filters are applied in order to identify uniqueness.
    #
    # The reason it does not physically exist is MySQL, PostgreSQL, etc. have a
    # different interpretation of uniqueness when it comes to NULL which is problematic
    # given the schema is optional.
    __table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),)

    table_name = Column(String(250), nullable=False)
    main_dttm_col = Column(String(250))
    database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False)
    fetch_values_predicate = Column(Text)
    owners = relationship(owner_class, secondary=sqlatable_user, backref="tables")
    database: Database = relationship(
        "Database",
        backref=backref("tables", cascade="all, delete-orphan"),
        foreign_keys=[database_id],
    )
    schema = Column(String(255))
    catalog = Column(String(256), nullable=True, default=None)
    sql = Column(MediumText())
    is_sqllab_view = Column(Boolean, default=False)
    template_params = Column(Text)
    extra = Column(Text)
    normalize_columns = Column(Boolean, default=False)
    always_filter_main_dttm = Column(Boolean, default=False)

    baselink = "tablemodelview"

    export_fields = [
        "table_name",
        "main_dttm_col",
        "description",
        "default_endpoint",
        "database_id",
        "offset",
        "cache_timeout",
        "schema",
        "sql",
        "params",
        "template_params",
        "filter_select_enabled",
        "fetch_values_predicate",
        "extra",
        "normalize_columns",
        "always_filter_main_dttm",
    ]
    update_from_object_fields = [f for f in export_fields if f != "database_id"]
    export_parent = "database"
    export_children = ["metrics", "columns"]

    sqla_aggregations = {
        "COUNT_DISTINCT": lambda column_name: sa.func.COUNT(sa.distinct(column_name)),
        "COUNT": sa.func.COUNT,
        "SUM": sa.func.SUM,
        "AVG": sa.func.AVG,
        "MIN": sa.func.MIN,
        "MAX": sa.func.MAX,
    }

    def __repr__(self) -> str:
        """
        Return the dataset's display name.

        ``name`` can be ``None`` (e.g. before ``table_name`` is set), and
        ``repr`` must return a ``str``, so the value is coerced explicitly.
        """
        return str(self.name)

    @property
    def db_extra(self) -> dict[str, Any]:
        """The ``extra`` settings of this dataset's database, as a dict."""
        database = self.database
        return database.get_extra()

    @staticmethod
    def _apply_cte(sql: str, cte: str | None) -> str:
        """
        Prefix a SELECT statement with a CTE, when one is given.

        :param sql: SELECT statement
        :param cte: CTE statement, or ``None``/empty for none
        :return: ``sql`` unchanged if ``cte`` is falsy, else ``cte``, a newline,
            then ``sql``
        """
        return f"{cte}\n{sql}" if cte else sql

    @property
    def db_engine_spec(self) -> builtins.type[BaseEngineSpec]:
        """
        The engine spec class of this dataset's database.

        The annotation goes through the ``builtins`` module because the class
        attribute ``type`` shadows the builtin inside the class body, and
        ``__builtins__`` may be a dict rather than the module outside
        ``__main__``, so it cannot be used to reach ``type``.
        """
        return self.database.db_engine_spec

    @property
    def changed_by_name(self) -> str:
        """String form of the last modifier, or ``""`` when unknown."""
        return str(self.changed_by) if self.changed_by else ""

    @property
    def connection(self) -> str:
        """String form of the dataset's database."""
        database = self.database
        return str(database)

    @property
    def description_markeddown(self) -> str:
        """The dataset description rendered through ``utils.markdown``."""
        raw_description = self.description
        return utils.markdown(raw_description)

    @property
    def datasource_name(self) -> str:
        """Generic datasource name; for SQL tables this is ``table_name``."""
        name = self.table_name
        return name

    @property
    def datasource_type(self) -> str:
        """Generic datasource type; mirrors the ``type`` class attribute."""
        kind = self.type
        return kind

    @property
    def database_name(self) -> str:
        """Name of the database this dataset belongs to."""
        database = self.database
        return database.name

    @classmethod
    def get_datasource_by_name(
        cls,
        datasource_name: str,
        schema: str | None,
        database_name: str,
    ) -> SqlaTable | None:
        """
        Look up a dataset by table name, schema and database name.

        An empty-string schema and ``None`` are treated as the same "no schema".

        :return: the first matching dataset, or ``None``
        """
        wanted_schema = schema or None
        candidates = (
            db.session.query(cls)
            .join(Database)
            .filter(cls.table_name == datasource_name)
            .filter(Database.database_name == database_name)
            .all()
        )
        # Normalizing ''/None in Python is simpler than doing it in SQL in a way
        # that works across dialects.
        return next(
            (tbl for tbl in candidates if wanted_schema == (tbl.schema or None)),
            None,
        )

    @property
    def link(self) -> Markup:
        """HTML anchor to the dataset's explore view, with the name escaped."""
        escaped_name = escape(self.name)
        return Markup(
            f'<a target="_blank" href="{self.explore_url}">{escaped_name}</a>'
        )

    def get_schema_perm(self) -> str | None:
        """Schema permission if a schema is set, the database one otherwise."""
        database = self.database
        return security_manager.get_schema_perm(database, self.schema or None)

    def get_perm(self) -> str:
        """
        Build this dataset's permission name, ``[<database>].[<table>](id:<id>)``.

        :return: dataset permission name
        :raises DatasetInvalidPermissionEvaluationException: When database is missing
        """
        database = self.database
        if database is None:
            raise DatasetInvalidPermissionEvaluationException()
        return f"[{database}].[{self.table_name}](id:{self.id})"

    @hybrid_property
    def name(self) -> str:  # pylint: disable=invalid-overridden-method
        """
        Display name of the dataset: ``<schema>.<table_name>`` when a schema is
        set, otherwise just ``table_name``.
        """
        # NOTE(review): as a ``hybrid_property`` this may also be evaluated on the
        # class to build a SQL expression, which is presumably why plain ``+``
        # concatenation is used instead of an f-string. Confirm before restyling.
        return self.schema + "." + self.table_name if self.schema else self.table_name

    @property
    def full_name(self) -> str:
        """Fully qualified datasource name as built by ``utils``."""
        return utils.get_datasource_full_name(
            self.database,
            self.table_name,
            schema=self.schema,
        )

    @property
    def dttm_cols(self) -> list[str]:
        """
        Names of the temporal columns.

        ``main_dttm_col`` is appended when it is set and not already flagged as
        temporal.
        """
        temporal_names = [col.column_name for col in self.columns if col.is_dttm]
        main_col = self.main_dttm_col
        if main_col and main_col not in temporal_names:
            temporal_names.append(main_col)
        return temporal_names

    @property
    def num_cols(self) -> list[str]:
        """Names of the numeric columns, in column order."""
        numeric_names: list[str] = []
        for col in self.columns:
            if col.is_numeric:
                numeric_names.append(col.column_name)
        return numeric_names

    @property
    def any_dttm_col(self) -> str | None:
        """The first temporal column name, or ``None`` if there is none."""
        return next(iter(self.dttm_cols), None)

    @property
    def html(self) -> str:
        """
        Render the dataset's columns as an HTML table with ``field``/``type`` headers.

        The labels are given to the ``DataFrame`` constructor so that a dataset
        with no columns still renders an empty two-column table. Assigning two
        labels to a zero-column frame would raise ``ValueError``.
        """
        df = pd.DataFrame(
            [(c.column_name, c.type) for c in self.columns],
            columns=["field", "type"],
        )
        return df.to_html(
            index=False,
            classes="dataframe table table-striped table-bordered table-condensed",
        )

    @property
    def sql_url(self) -> str:
        """The database's SQL URL with this table preselected."""
        base_url = self.database.sql_url
        return base_url + "?table_name=" + str(self.table_name)

    def external_metadata(self) -> list[ResultSetColumnType]:
        """
        Column metadata as reported by the database.

        Virtual datasets (with ``sql``) are described from their query; physical
        ones from the table itself.
        """
        # todo(yongjie): create a physical table column type in a separate PR
        if not self.sql:
            return get_physical_table_metadata(
                database=self.database,
                table=Table(self.table_name, self.schema or None, self.catalog),
                normalize_columns=self.normalize_columns,
            )
        return get_virtual_table_metadata(dataset=self)

    @property
    def time_column_grains(self) -> dict[str, Any]:
        """Temporal column names together with the database's time grain names."""
        time_columns = self.dttm_cols
        time_grains = [grain.name for grain in self.database.grains()]
        return {"time_columns": time_columns, "time_grains": time_grains}

    @property
    def select_star(self) -> str | None:
        """A ``SELECT *`` statement for this table."""
        table_ref = Table(self.table_name, self.schema or None, self.catalog)
        # Column listing and latest-partition lookup are disabled to avoid
        # the expensive cost of inspecting the DB.
        return self.database.select_star(
            table_ref,
            show_cols=False,
            latest_partition=False,
        )

    @property
    def health_check_message(self) -> str | None:
        """Result of the configured ``DATASET_HEALTH_CHECK`` hook, if any."""
        check = config["DATASET_HEALTH_CHECK"]
        if not check:
            return None
        return check(self)

    @property
    def granularity_sqla(self) -> list[tuple[Any, Any]]:
        """Temporal columns as ``(value, label)`` choice pairs."""
        temporal_names = self.dttm_cols
        return utils.choicify(temporal_names)

    @property
    def time_grain_sqla(self) -> list[tuple[Any, Any]]:
        """The database's time grains as ``(duration, name)`` pairs."""
        grains = self.database.grains() or []
        return [(grain.duration, grain.name) for grain in grains]

    @property
    def data(self) -> dict[str, Any]:
        """
        Serialized dataset payload.

        Extends the base datasource payload with SQL-table specific fields when
        ``type`` is ``"table"``.
        """
        payload = super().data
        if self.type != "table":
            return payload
        payload.update(
            {
                "granularity_sqla": self.granularity_sqla,
                "time_grain_sqla": self.time_grain_sqla,
                "main_dttm_col": self.main_dttm_col,
                "fetch_values_predicate": self.fetch_values_predicate,
                "template_params": self.template_params,
                "is_sqllab_view": self.is_sqllab_view,
                "health_check_message": self.health_check_message,
                "extra": self.extra,
                "owners": self.owners_data,
                "always_filter_main_dttm": self.always_filter_main_dttm,
                "normalize_columns": self.normalize_columns,
            }
        )
        return payload

    @property
    def extra_dict(self) -> dict[str, Any]:
        """
        The dataset's ``extra`` JSON parsed into a dictionary.

        Returns an empty dict when ``extra`` is unset, is not valid JSON, or is
        valid JSON that is not an object (e.g. ``"null"`` or a list), so callers
        can always treat the result as a mapping.
        """
        try:
            parsed = json.loads(self.extra)
        except (TypeError, json.JSONDecodeError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    def get_fetch_values_predicate(
        self,
        template_processor: BaseTemplateProcessor | None = None,
    ) -> TextClause:
        """
        The dataset's fetch-values predicate as a text clause.

        :param template_processor: optional processor rendering Jinja in the
            predicate
        :return: the (rendered) predicate wrapped by ``self.text``
        :raises QueryObjectValidationError: if rendering the Jinja template fails
        """
        fetch_values_predicate = self.fetch_values_predicate
        try:
            # Template rendering is what raises ``TemplateError``, so it has to be
            # inside the ``try`` for the error to surface as a validation error.
            if template_processor:
                fetch_values_predicate = template_processor.process_template(
                    fetch_values_predicate
                )
            return self.text(fetch_values_predicate)
        except TemplateError as ex:
            raise QueryObjectValidationError(
                _(
                    "Error in jinja expression in fetch values predicate: %(msg)s",
                    msg=ex.message,
                )
            ) from ex

    def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
        """Template processor bound to this dataset and its database."""
        return get_template_processor(
            table=self,
            database=self.database,
            **kwargs,
        )

    def get_query_str(self, query_obj: QueryObjectDict) -> str:
        """All prequeries plus the main query, as ``;``-terminated SQL text."""
        extended = self.get_query_str_extended(query_obj)
        statements = extended.prequeries + [extended.sql]
        joined = ";\n\n".join(statements)
        return f"{joined};"

    def get_sqla_table(self) -> TableClause:
        """Lightweight table clause for the physical table, schema-qualified if set."""
        sqla_table = table(self.table_name)
        schema = self.schema
        if schema:
            sqla_table.schema = schema
        return sqla_table

    def get_from_clause(
        self, template_processor: BaseTemplateProcessor | None = None
    ) -> tuple[TableClause | Alias, str | None]:
        """
        Determine what the query selects FROM.

        Physical datasets use their table directly. Virtual datasets use their
        rendered SQL, which must be read-only (or unparseable), either as an
        aliased sub-query or, if the engine extracts a CTE from it, as a
        reference to the CTE alias.

        :return: ``(from_clause, cte)``; ``cte`` is the CTE text when one is used
        :raises QueryObjectValidationError: if the virtual SQL is not read-only
        """
        if not self.is_virtual:
            return self.get_sqla_table(), None

        engine_spec = self.db_engine_spec
        from_sql = self.get_rendered_sql(template_processor)
        parsed_query = ParsedQuery(from_sql, engine=engine_spec.engine)
        allowed = parsed_query.is_unknown() or engine_spec.is_readonly_query(
            parsed_query
        )
        if not allowed:
            raise QueryObjectValidationError(
                _("Virtual dataset query must be read-only")
            )

        cte = engine_spec.get_cte_query(from_sql)
        if cte:
            return table(engine_spec.cte_alias), cte
        subquery = TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
        return subquery, cte

    def adhoc_metric_to_sqla(
        self,
        metric: AdhocMetric,
        columns_by_name: dict[str, TableColumn],
        template_processor: BaseTemplateProcessor | None = None,
    ) -> ColumnElement:
        """
        Turn an adhoc metric into a sqlalchemy column.

        SQL metrics use their ``sqlExpression`` after rendering and validation.
        SIMPLE metrics apply one of ``sqla_aggregations`` to a column, reusing the
        dataset's column definition when one matches by name.

        :param dict metric: Adhoc metric definition
        :param dict columns_by_name: Columns for the current table
        :param template_processor: template_processor instance
        :returns: The metric as a sqlalchemy column, labelled with its metric name
        :rtype: sqlalchemy.sql.column
        :raises QueryObjectValidationError: if ``expressionType`` is not supported
        """
        expression_type = metric.get("expressionType")
        label = utils.get_metric_name(metric)

        if expression_type == utils.AdhocMetricExpressionType.SQL:
            sql = _process_sql_expression(
                expression=metric["sqlExpression"],
                database_id=self.database_id,
                schema=self.schema,
                template_processor=template_processor,
            )
            sqla_metric = literal_column(sql)
        elif expression_type == utils.AdhocMetricExpressionType.SIMPLE:
            column_spec = metric.get("column") or {}
            column_name = cast(str, column_spec.get("column_name"))
            table_column: TableColumn | None = columns_by_name.get(column_name)
            if not table_column:
                sqla_column = column(column_name)
            else:
                sqla_column = table_column.get_sqla_col(
                    template_processor=template_processor
                )
            aggregate = self.sqla_aggregations[metric["aggregate"]]
            sqla_metric = aggregate(sqla_column)
        else:
            raise QueryObjectValidationError("Adhoc metric expressionType is invalid")

        return self.make_sqla_column_compatible(sqla_metric, label)

    def adhoc_column_to_sqla(  # pylint: disable=too-many-locals
        self,
        col: AdhocColumn,
        force_type_check: bool = False,
        template_processor: BaseTemplateProcessor | None = None,
    ) -> ColumnElement:
        """
        Turn an adhoc column into a sqlalchemy column.

        If the processed expression names a column of this dataset, that
        column's own representation is reused. Otherwise the expression is used
        verbatim. When a time grain is requested or ``force_type_check`` is set,
        a ``LIMIT 1`` probe query is compiled to find out whether the expression
        is temporal.

        :param col: Adhoc column definition
        :param force_type_check: Should the column type be checked in the db.
               This is needed to validate if a filter with an adhoc column
               is applicable.
        :param template_processor: template_processor instance
        :returns: The column as a sqlalchemy column, truncated to the time grain
            when applicable and labelled with the adhoc column's name
        :rtype: sqlalchemy.sql.column
        :raises ColumnNotFoundException: if probing the expression's type fails
        """
        label = utils.get_column_name(col)
        expression = _process_sql_expression(
            expression=col["sqlExpression"],
            database_id=self.database_id,
            schema=self.schema,
            template_processor=template_processor,
        )
        time_grain = col.get("timeGrain")
        # Only the base-axis column gets time-grain treatment.
        has_timegrain = col.get("columnType") == "BASE_AXIS" and time_grain
        is_dttm = False
        pdf = None
        if col_in_metadata := self.get_column(expression):
            sqla_column = col_in_metadata.get_sqla_col(
                template_processor=template_processor
            )
            is_dttm = col_in_metadata.is_temporal
            pdf = col_in_metadata.python_date_format
        else:
            sqla_column = literal_column(expression)
            if has_timegrain or force_type_check:
                try:
                    # probe adhoc column type. The CTE part of the FROM clause is
                    # not needed here; it is not unpacked into ``_`` because that
                    # would make the gettext alias ``_`` local to this function.
                    tbl, _cte = self.get_from_clause(template_processor)
                    qry = sa.select([sqla_column]).limit(1).select_from(tbl)
                    sql = self.database.compile_sqla_query(qry)
                    col_desc = get_columns_description(
                        self.database,
                        self.catalog,
                        self.schema or None,
                        sql,
                    )
                    if not col_desc:
                        raise SupersetGenericDBErrorException("Column not found")
                    is_dttm = col_desc[0]["is_dttm"]  # type: ignore
                except SupersetGenericDBErrorException as ex:
                    raise ColumnNotFoundException(message=str(ex)) from ex

        if is_dttm and has_timegrain:
            sqla_column = self.db_engine_spec.get_timestamp_expr(
                col=sqla_column,
                pdf=pdf,
                time_grain=time_grain,
            )
        return self.make_sqla_column_compatible(sqla_column, label)

    def make_orderby_compatible(
        self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
    ) -> None:
        """
        If needed, make sure aliases for selected columns are not used in
        `ORDER BY`.

        Some engines (e.g. Presto) cannot resolve an `ORDER BY` reference to the
        source column when a `SELECT` alias has the same name. For such engines,
        every labelled select expression whose alias appears inside parentheses
        in an `ORDER BY` expression gets a ``__`` suffix. The final output column
        names are restored later from ``labels_expected`` after querying.
        Mutates ``select_exprs`` in place.
        """
        if self.db_engine_spec.allows_alias_to_source_column:
            return

        def _alias_clashes(expr: ColumnElement) -> bool:
            if not isinstance(expr, Label):
                return False
            pattern = re.compile(f"\\(.*\\b{re.escape(expr.name)}\\b.*\\)", re.IGNORECASE)
            return any(pattern.search(str(ob)) for ob in orderby_exprs)

        for expr in select_exprs:
            if _alias_clashes(expr):
                expr.name = f"{expr.name}__"

    def get_sqla_row_level_filters(
        self,
        template_processor: BaseTemplateProcessor,
    ) -> list[TextClause]:
        """
        Return the row level security filters for this table and the current user.

        Ungrouped filters (plus guest filters when ``EMBEDDED_SUPERSET`` is
        enabled) come first. They are followed by one OR-ed clause per filter
        group key.

        :param template_processor: The template processor to apply to the filters.
        :returns: A list of SQL clauses to be ANDed together.
        :raises QueryObjectValidationError: if a filter's Jinja template fails
        """
        ungrouped: list[TextClause] = []
        grouped: dict[int | str, list[TextClause]] = defaultdict(list)

        def _render(raw_clause: str) -> TextClause:
            return self.text(f"({template_processor.process_template(raw_clause)})")

        try:
            for filter_ in security_manager.get_rls_filters(self):
                clause = _render(filter_.clause)
                if filter_.group_key:
                    grouped[filter_.group_key].append(clause)
                else:
                    ungrouped.append(clause)

            if is_feature_enabled("EMBEDDED_SUPERSET"):
                for rule in security_manager.get_guest_rls_filters(self):
                    ungrouped.append(_render(rule["clause"]))

            ungrouped.extend([or_(*clauses) for clauses in grouped.values()])
            return ungrouped
        except TemplateError as ex:
            raise QueryObjectValidationError(
                _(
                    "Error in jinja expression in RLS filters: %(msg)s",
                    msg=ex.message,
                )
            ) from ex

    def text(self, clause: str) -> TextClause:
        """Wrap raw SQL in a text clause using the engine spec's rules."""
        engine_spec = self.db_engine_spec
        return engine_spec.get_text_clause(clause)

    def _get_series_orderby(
        self,
        series_limit_metric: Metric,
        metrics_by_name: dict[str, SqlMetric],
        columns_by_name: dict[str, TableColumn],
        template_processor: BaseTemplateProcessor | None = None,
    ) -> Column:
        """
        Resolve the metric used to rank series when a series limit applies.

        :param series_limit_metric: adhoc metric dict or a saved metric's name
        :param metrics_by_name: saved metrics of this dataset, keyed by name
        :param columns_by_name: columns of this dataset, keyed by name
        :param template_processor: processor applied to the metric's SQL
        :return: the metric as a SQLAlchemy column element
        :raises QueryObjectValidationError: if a named metric does not exist
        """
        if utils.is_adhoc_metric(series_limit_metric):
            assert isinstance(series_limit_metric, dict)
            # Forward the template processor so Jinja in adhoc SQL metrics is
            # rendered, exactly as it is for saved metrics below.
            return self.adhoc_metric_to_sqla(
                series_limit_metric,
                columns_by_name,
                template_processor=template_processor,
            )
        if (
            isinstance(series_limit_metric, str)
            and series_limit_metric in metrics_by_name
        ):
            return metrics_by_name[series_limit_metric].get_sqla_col(
                template_processor=template_processor
            )
        raise QueryObjectValidationError(
            _("Metric '%(metric)s' does not exist", metric=series_limit_metric)
        )

    def _normalize_prequery_result_type(
        self,
        row: pd.Series,
        dimension: str,
        columns_by_name: dict[str, TableColumn],
    ) -> str | int | float | bool | TextClause:
        """
        Convert a prequery result type to its equivalent Python type.

        Some databases like Druid will return timestamps as strings, but do not perform
        automatic casting when comparing these strings to a timestamp. For cases like
        this we convert the value via the appropriate SQL transform.

        :param row: A prequery record
        :param dimension: The dimension name
        :param columns_by_name: The mapping of columns by name
        :return: equivalent primitive python type, or a text clause holding the
            engine's datetime literal for string values of temporal columns
        """
        value = row[dimension]

        # Unwrap numpy scalars (e.g. ``np.int64``) into native Python values.
        if isinstance(value, np.generic):
            value = value.item()

        column_ = columns_by_name[dimension]

        if column_.type and column_.is_temporal and isinstance(value, str):
            # The database extras are only needed here, so they are read lazily
            # rather than once per call for every row and dimension.
            sql = self.db_engine_spec.convert_dttm(
                column_.type, dateutil.parser.parse(value), db_extra=self.db_extra
            )

            if sql:
                value = self.text(sql)

        return value

    def _get_top_groups(
        self,
        df: pd.DataFrame,
        dimensions: list[str],
        groupby_exprs: dict[str, Any],
        columns_by_name: dict[str, TableColumn],
    ) -> ColumnElement:
        """
        Build a filter matching any of the groups (rows) in a prequery result.

        Each row becomes an AND of ``dimension == value`` conditions, and the rows
        are OR-ed together.
        """

        def _row_clause(row: pd.Series) -> ColumnElement:
            conditions = []
            for dimension in dimensions:
                value = self._normalize_prequery_result_type(
                    row,
                    dimension,
                    columns_by_name,
                )
                conditions.append(groupby_exprs[dimension] == value)
            return and_(*conditions)

        return or_(*[_row_clause(row) for _idx, row in df.iterrows()])

    def query(self, query_obj: QueryObjectDict) -> QueryResult:
        """
        Run the query described by ``query_obj`` and wrap the outcome.

        ``SupersetErrorException``/``SupersetErrorsException`` are re-raised.
        Any other failure is captured in the result, which then carries a
        ``FAILED`` status, an empty DataFrame and the extracted errors.

        :param query_obj: the query object to compile and execute
        :return: QueryResult with the DataFrame, executed SQL, duration and errors
        """
        qry_start_dttm = datetime.now()
        query_str_ext = self.get_query_str_extended(query_obj)
        sql = query_str_ext.sql
        status = QueryStatus.SUCCESS
        errors = None
        error_message = None

        def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
            """
            Some engines change the case or generate bespoke column names, either by
            default or due to lack of support for aliasing. This function ensures that
            the column names in the DataFrame correspond to what is expected by
            the viz components.

            Sometimes a query may also contain only order by columns that are not used
            as metrics or groupby columns, but need to present in the SQL `select`,
            filtering by `labels_expected` make sure we only return columns users want.

            :param df: Original DataFrame returned by the engine
            :return: Mutated DataFrame
            """
            labels_expected = query_str_ext.labels_expected
            if df is not None and not df.empty:
                if len(df.columns) < len(labels_expected):
                    raise QueryObjectValidationError(
                        _("Db engine did not return all queried columns")
                    )
                if len(df.columns) > len(labels_expected):
                    df = df.iloc[:, 0 : len(labels_expected)]
                df.columns = labels_expected
            return df

        try:
            df = self.database.get_df(
                sql, self.schema or None, mutator=assign_column_label
            )
        except (SupersetErrorException, SupersetErrorsException):
            # SupersetError(s) exception should not be captured; instead, they should
            # bubble up to the Flask error handler so they are returned as proper SIP-40
            # errors. This is particularly important for database OAuth2, see SIP-85.
            raise
        except Exception as ex:  # pylint: disable=broad-except
            # TODO (betodealmeida): review exception handling while querying the external
            # database. Ideally we'd expect and handle external database error, but
            # everything else / the default should be to let things bubble up.
            df = pd.DataFrame()
            status = QueryStatus.FAILED
            logger.warning(
                "Query %s on schema %s failed", sql, self.schema, exc_info=True
            )
            db_engine_spec = self.db_engine_spec
            errors = [
                dataclasses.asdict(error) for error in db_engine_spec.extract_errors(ex)
            ]
            error_message = utils.error_msg_from_exception(ex)

        return QueryResult(
            applied_template_filters=query_str_ext.applied_template_filters,
            applied_filter_columns=query_str_ext.applied_filter_columns,
            rejected_filter_columns=query_str_ext.rejected_filter_columns,
            status=status,
            df=df,
            duration=datetime.now() - qry_start_dttm,
            query=sql,
            errors=errors,
            error_message=error_message,
        )

    def get_sqla_table_object(self) -> Table:
        """The reflected table object for this dataset's physical table."""
        table_ref = Table(self.table_name, self.schema or None, self.catalog)
        return self.database.get_table(table_ref)

    def fetch_metadata(self, commit: bool = True) -> MetadataResult:
        """
        Fetches the metadata for the table and merges it in.

        Columns reported by the database are synced. New names are added, and
        existing ones get their type refreshed and ``expression`` cleared.
        Calculated columns (non-empty ``expression``) that no physical column
        replaced are kept. Metrics reported by the database are passed to
        ``add_missing_metrics``.

        :param commit: should the changes be committed or not.
        :return: MetadataResult with lists of added, removed and modified column
            names.
        """
        new_columns = self.external_metadata()
        metrics = [
            SqlMetric(**metric)
            for metric in self.database.get_metrics(
                Table(
                    self.table_name,
                    self.schema or None,
                    self.catalog,
                )
            )
        ]
        any_date_col = None
        db_engine_spec = self.db_engine_spec

        # If no `self.id`, then this is a new table, no need to fetch columns
        # from db.  Passing in `self.id` to query will actually automatically
        # generate a new id, which can be tricky during certain transactions.
        old_columns = (
            (
                db.session.query(TableColumn)
                .filter(TableColumn.table_id == self.id)
                .all()
            )
            if self.id
            else self.columns
        )

        old_columns_by_name: dict[str, TableColumn] = {
            col.column_name: col for col in old_columns
        }
        # Names currently reported by the database, built once so each
        # membership test below is O(1).
        new_column_names = {col["column_name"] for col in new_columns}
        # NOTE(review): calculated columns with no physical counterpart are also
        # listed as "removed" here even though they are added back below;
        # confirm whether callers rely on that before changing it.
        results = MetadataResult(
            removed=[
                name for name in old_columns_by_name if name not in new_column_names
            ]
        )

        # clear old columns before adding modified columns back
        columns = []
        for col in new_columns:
            old_column = old_columns_by_name.pop(col["column_name"], None)
            if not old_column:
                results.added.append(col["column_name"])
                new_column = TableColumn(
                    column_name=col["column_name"],
                    type=col["type"],
                    table=self,
                )
                new_column.is_dttm = new_column.is_temporal
                db_engine_spec.alter_new_orm_column(new_column)
            else:
                new_column = old_column
                if new_column.type != col["type"]:
                    results.modified.append(col["column_name"])
                new_column.type = col["type"]
                new_column.expression = ""
            new_column.groupby = True
            new_column.filterable = True
            columns.append(new_column)
            if not any_date_col and new_column.is_temporal:
                any_date_col = col["column_name"]

        # add back calculated (virtual) columns
        columns.extend([col for col in old_columns if col.expression])
        self.columns = columns

        if not self.main_dttm_col:
            self.main_dttm_col = any_date_col
        self.add_missing_metrics(metrics)

        # Apply config supplied mutations.
        config["SQLA_TABLE_MUTATOR"](self)

        db.session.merge(self)
        if commit:
            db.session.commit()
        return results

    @classmethod
    def query_datasources_by_name(
        cls,
        database: Database,
        datasource_name: str,
        schema: str | None = None,
    ) -> list[SqlaTable]:
        """
        Find the datasets of ``database`` whose table name is ``datasource_name``.

        :param database: The database the datasets must belong to
        :param datasource_name: The exact table name to match
        :param schema: When truthy, additionally restrict matches to this schema
        :return: All matching datasets
        """
        matches = db.session.query(cls).filter_by(
            database_id=database.id,
            table_name=datasource_name,
        )
        if schema:
            matches = matches.filter_by(schema=schema)
        return matches.all()

    @classmethod
    def query_datasources_by_permissions(  # pylint: disable=invalid-name
        cls,
        database: Database,
        permissions: set[str],
        schema_perms: set[str],
    ) -> list[SqlaTable]:
        """
        Find the datasets of ``database`` reachable through the given permissions.

        A dataset matches when its ``perm`` is in ``permissions`` or its
        ``schema_perm`` is in ``schema_perms``.

        :param database: The database the datasets must belong to
        :param permissions: Dataset-level permission names
        :param schema_perms: Schema-level permission names
        :return: All matching datasets; empty when both permission sets are empty
        """
        # TODO(hughhhh): add unit test
        if not permissions and not schema_perms:
            # Both predicates would be ``x IN ()``, which never matches, so the
            # database round-trip can be skipped entirely.
            return []
        # Filter on ``cls`` (not the hard-coded ``SqlaTable``) so the filtered
        # columns always belong to the entity being queried, as in the sibling
        # classmethods.
        return (
            db.session.query(cls)
            .filter_by(database_id=database.id)
            .filter(
                or_(
                    cls.perm.in_(permissions),
                    cls.schema_perm.in_(schema_perms),
                )
            )
            .all()
        )

    @classmethod
    def get_eager_sqlatable_datasource(cls, datasource_id: int) -> SqlaTable:
        """
        Fetch one dataset by id with its ``columns`` and ``metrics`` preloaded.

        Both relationships are loaded with ``subqueryload``. ``.one()`` is used,
        so a missing (or duplicated) id raises instead of returning ``None``.

        :param datasource_id: Primary key of the dataset
        :return: The dataset with columns and metrics eagerly loaded
        """
        eager_loads = (
            sa.orm.subqueryload(cls.columns),
            sa.orm.subqueryload(cls.metrics),
        )
        eager_query = db.session.query(cls).options(*eager_loads)
        return eager_query.filter_by(id=datasource_id).one()

    @classmethod
    def get_all_datasources(cls) -> list[SqlaTable]:
        """Return every dataset that passes the ``default_query`` filtering."""
        return cls.default_query(db.session.query(cls)).all()

    @staticmethod
    def default_query(qry: Query) -> Query:
        """
        Narrow ``qry`` to datasets whose ``is_sqllab_view`` flag is False.

        :param qry: The query to narrow
        :return: The narrowed query
        """
        exclude_sqllab_views = {"is_sqllab_view": False}
        return qry.filter_by(**exclude_sqllab_views)

    def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool:
        """
        Detects the presence of calls to `ExtraCache` methods in items in query_obj that
        can be templated. If any are present, the query must be evaluated to extract
        additional keys for the cache key. This method is needed to avoid executing the
        template code unnecessarily, as it may contain expensive calls, e.g. to extract
        the latest partition of a database.

        The candidate statements are the dataset SQL, the fetch-values predicate, the
        ``where``/``having`` entries of ``query_obj["extras"]`` and, when RLS is
        supported, the clauses of the applicable RLS filters. Missing or empty
        statements (``None``/``""``) are skipped instead of being handed to the regex.

        :param query_obj: query object to analyze
        :return: True if there are call(s) to an `ExtraCache` method, False otherwise
        """
        # ``extras`` may be present but set to ``None``; treat that as "no extras"
        # rather than failing on the membership test / lookup.
        extras = query_obj.get("extras") or {}
        templatable_statements: list[str] = [
            statement
            for statement in (
                self.sql,
                self.fetch_values_predicate,
                extras.get("where"),
                extras.get("having"),
            )
            # ``regex.search(None)`` raises TypeError, and "" cannot match anyway.
            if statement
        ]
        if self.is_rls_supported:
            templatable_statements.extend(
                f.clause for f in security_manager.get_rls_filters(self)
            )
        return any(
            ExtraCache.regex.search(statement) for statement in templatable_statements
        )

    def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
        """
        Combine the parent class's extra cache keys with any `ExtraCache` keys.

        The (possibly expensive) templated query is only built through
        ``get_sqla_query`` when ``has_extra_cache_key_calls`` reports that some
        templatable statement references an `ExtraCache` method.

        :param query_obj: query object to analyze
        :return: The extra cache keys
        """
        keys = super().get_extra_cache_keys(query_obj)
        if not self.has_extra_cache_key_calls(query_obj):
            return keys
        rendered = self.get_sqla_query(**query_obj)
        keys += rendered.extra_cache_keys
        return keys

    @property
    def quote_identifier(self) -> Callable[[str], str]:
        """The identifier-quoting callable exposed by this dataset's database."""
        database = self.database
        return database.quote_identifier

    @staticmethod
    def before_update(
        mapper: Mapper,
        connection: Connection,
        target: SqlaTable,
    ) -> None:
        """
        ``before_update`` listener for datasets, fired when any field changes.

        The ``database`` relationship is resolved first (via ``load_database``),
        then the work is delegated to ``security_manager.dataset_before_update``.

        :param mapper: The table mapper
        :param connection: The SQLAlchemy connection used for the flush
        :param target: The dataset instance being persisted
        :raises Exception: Anything raised by
            ``security_manager.dataset_before_update`` propagates. The previous
            docs said this happens when the target table is not unique; that is
            not visible from this method, so confirm it in the security manager.
        """
        # Resolve the database before the security hook runs.
        target.load_database()
        security_manager.dataset_before_update(mapper, connection, target)

    @staticmethod
    def update_column(  # pylint: disable=unused-argument
        mapper: Mapper, connection: Connection, target: SqlMetric | TableColumn
    ) -> None:
        """
        ``after_update`` listener for dataset metrics and columns.

        Emits an ``UPDATE`` on the owning dataset's row through the session that
        ``target`` belongs to. Its purpose is to force an update to the table's
        ``changed_on`` value whenever a metric or column changes, which busts the
        cache key for all charts that use the table. The statement sets no explicit
        values; presumably the column-level ``onupdate`` defaults supply the SET
        clause (confirm against the audit mixin).

        :param mapper: Unused.
        :param connection: Unused.
        :param target: The metric or column that was updated.
        """
        owning_session = inspect(target).session
        dataset_id = target.table.id
        touch_dataset = update(SqlaTable).where(SqlaTable.id == dataset_id)
        owning_session.execute(touch_dataset)

    @staticmethod
    def after_insert(
        mapper: Mapper,
        connection: Connection,
        target: SqlaTable,
    ) -> None:
        """
        ``after_insert`` listener: hand the new dataset to the security manager.

        The ``database`` relationship is resolved first, then
        ``security_manager.dataset_after_insert`` updates dataset permissions.

        :param mapper: The table mapper
        :param connection: The SQLAlchemy connection used for the flush
        :param target: The dataset instance that was inserted
        """
        target.load_database()
        security_manager.dataset_after_insert(mapper, connection, target)

    @staticmethod
    def after_delete(
        mapper: Mapper,
        connection: Connection,
        sqla_table: SqlaTable,
    ) -> None:
        """
        ``after_delete`` listener: let the security manager update dataset
        permissions for the removed dataset.

        :param mapper: The table mapper
        :param connection: The SQLAlchemy connection used for the flush
        :param sqla_table: The dataset instance that was deleted
        """
        security_manager.dataset_after_delete(mapper, connection, sqla_table)

    def load_database(self: SqlaTable) -> None:
        """
        Make ``self.database`` the ``Database`` row referenced by ``database_id``.

        Nothing happens when ``database_id`` is unset, or when the attached
        ``database`` already carries that id. Otherwise the row is fetched through
        the session this dataset belongs to. This assumes the instance is attached
        to a session. ``.one()`` raises if no such row exists.
        """
        # somehow the database attribute is not loaded on access
        if not self.database_id:
            return
        attached = self.database
        if attached and attached.id == self.database_id:
            return
        owning_session = inspect(self).session
        self.database = (
            owning_session.query(Database).filter_by(id=self.database_id).one()
        )


# Wire the SqlaTable static methods into SQLAlchemy's ORM events: dataset
# before_update / after_insert / after_delete hooks, plus after_update hooks on
# SqlMetric and TableColumn that both route to SqlaTable.update_column so that
# editing a metric or column also touches the owning dataset's row.
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
sa.event.listen(SqlaTable, "after_delete", SqlaTable.after_delete)
sa.event.listen(SqlMetric, "after_update", SqlaTable.update_column)
sa.event.listen(TableColumn, "after_update", SqlaTable.update_column)

# Association table linking row-level-security filters to roles: each row pairs
# a ``row_level_security_filters.id`` with an ``ab_role.id``. It is used as the
# ``secondary`` of RowLevelSecurityFilter.roles.
RLSFilterRoles = DBTable(
    "rls_filter_roles",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("role_id", Integer, ForeignKey("ab_role.id"), nullable=False),
    Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")),
)

# Association table linking row-level-security filters to datasets: each row
# pairs a ``row_level_security_filters.id`` with a ``tables.id``. It is used as
# the ``secondary`` of RowLevelSecurityFilter.tables.
RLSFilterTables = DBTable(
    "rls_filter_tables",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("table_id", Integer, ForeignKey("tables.id")),
    Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")),
)


class RowLevelSecurityFilter(Model, AuditMixinNullable):
    """
    Custom where clauses attached to Tables and Roles.
    """

    __tablename__ = "row_level_security_filters"
    id = Column(Integer, primary_key=True)
    name = Column(String(255), unique=True, nullable=False)
    description = Column(Text)
    filter_type = Column(
        Enum(
            *[filter_type.value for filter_type in utils.RowLevelSecurityFilterType],
            name="filter_type_enum",
        ),
    )
    group_key = Column(String(255), nullable=True)
    roles = relationship(
        security_manager.role_model,
        secondary=RLSFilterRoles,
        backref="row_level_security_filters",
    )
    tables = relationship(
        SqlaTable,
        overlaps="table",
        secondary=RLSFilterTables,
        backref="row_level_security_filters",
    )
    clause = Column(MediumText(), nullable=False)