# superset/db_engine_specs/databricks.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import datetime
from typing import Any, TYPE_CHECKING, TypedDict, Union
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.validate import Range
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from superset.constants import TimeGrain, USER_AGENT
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.db_engine_specs.hive import HiveEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.utils import json
from superset.utils.network import is_hostname_valid, is_port_open
if TYPE_CHECKING:
from superset.models.core import Database
class DatabricksBaseSchema(Schema):
    """
    Fields that are required for both Databricks drivers that use a
    dynamic form.
    """
    # Databricks personal access token; used as the URI password.
    access_token = fields.Str(required=True)
    # Workspace hostname.
    host = fields.Str(required=True)
    # TCP port; upper bound is exclusive (0 <= port < 65536), matching the
    # range check in validate_parameters below.
    port = fields.Integer(
        required=True,
        metadata={"description": __("Database port")},
        validate=Range(min=0, max=2**16, max_inclusive=False),
    )
    # When true, the encryption_parameters ("ssl=1") are added to the URI query.
    encryption = fields.Boolean(
        required=False,
        metadata={"description": __("Use an encrypted connection to the database")},
    )
class DatabricksBaseParametersType(TypedDict):
    """
    The parameters are all the keys that do not exist on the Database model.
    These are used to build the sqlalchemy uri.
    """
    # Personal access token, used as the URI password.
    access_token: str
    host: str
    port: int
    # Whether to add the encryption query arguments ("ssl=1") to the URI.
    encryption: bool
class DatabricksNativeSchema(DatabricksBaseSchema):
    """
    Additional fields required only for the DatabricksNativeEngineSpec.
    """
    # Database name placed in the path component of the SQLAlchemy URI.
    database = fields.Str(required=True)
class DatabricksNativePropertiesSchema(DatabricksNativeSchema):
    """
    Properties required only for the DatabricksNativeEngineSpec.
    """
    # HTTP path of the cluster/endpoint; for this driver it travels in the
    # "extra" engine_params connect_args rather than in the URI itself.
    http_path = fields.Str(required=True)
class DatabricksNativeParametersType(DatabricksBaseParametersType):
    """
    Additional parameters required only for the DatabricksNativeEngineSpec.
    """
    # Database name placed in the path component of the SQLAlchemy URI.
    database: str
class DatabricksNativePropertiesType(TypedDict):
    """
    All properties that need to be available to the DatabricksNativeEngineSpec
    in order to create a connection if the dynamic form is used.
    """
    parameters: DatabricksNativeParametersType
    # JSON-encoded string; parsed by validate_parameters for engine_params.
    extra: str
class DatabricksPythonConnectorSchema(DatabricksBaseSchema):
    """
    Additional fields required only for the DatabricksPythonConnectorEngineSpec.
    """
    # HTTP path of the endpoint; becomes the "http_path" URI query argument.
    http_path_field = fields.Str(required=True)
    # Become the "catalog" and "schema" URI query arguments respectively.
    default_catalog = fields.Str(required=True)
    default_schema = fields.Str(required=True)
class DatabricksPythonConnectorParametersType(DatabricksBaseParametersType):
    """
    Additional parameters required only for the DatabricksPythonConnectorEngineSpec.
    """
    # HTTP path of the endpoint; becomes the "http_path" URI query argument.
    http_path_field: str
    # Become the "catalog" and "schema" URI query arguments respectively.
    default_catalog: str
    default_schema: str
class DatabricksPythonConnectorPropertiesType(TypedDict):
    """
    All properties that need to be available to the DatabricksPythonConnectorEngineSpec
    in order to create a connection if the dynamic form is used.
    """
    parameters: DatabricksPythonConnectorParametersType
    # JSON-encoded string; parsed by validate_parameters for engine_params.
    extra: str
# SQL expressions used to truncate a temporal column ({col}) to each Superset
# time grain; shared by all the Databricks engine specs below. `None` means
# the column is used unmodified.
time_grain_expressions: dict[str | None, str] = {
    None: "{col}",
    TimeGrain.SECOND: "date_trunc('second', {col})",
    TimeGrain.MINUTE: "date_trunc('minute', {col})",
    TimeGrain.HOUR: "date_trunc('hour', {col})",
    TimeGrain.DAY: "date_trunc('day', {col})",
    TimeGrain.WEEK: "date_trunc('week', {col})",
    TimeGrain.MONTH: "date_trunc('month', {col})",
    TimeGrain.QUARTER: "date_trunc('quarter', {col})",
    TimeGrain.YEAR: "date_trunc('year', {col})",
    # Shift by one day so the week boundary lands on the desired weekday.
    TimeGrain.WEEK_ENDING_SATURDAY: (
        "date_trunc('week', {col} + interval '1 day') + interval '5 days'"
    ),
    TimeGrain.WEEK_STARTING_SUNDAY: (
        "date_trunc('week', {col} + interval '1 day') - interval '1 day'"
    ),
}
class DatabricksHiveEngineSpec(HiveEngineSpec):
    """
    Engine spec for Databricks Interactive Clusters, connected through the
    PyHive driver.
    """
    engine_name = "Databricks Interactive Cluster"
    engine = "databricks"
    drivers = {"pyhive": "Hive driver for Interactive Cluster"}
    default_driver = "pyhive"
    # Column name to read when listing functions — presumably consumed by
    # HiveEngineSpec's function-listing logic; confirm against the base class.
    _show_functions_column = "function"
    _time_grain_expressions = time_grain_expressions
class DatabricksBaseEngineSpec(BaseEngineSpec):
    """
    Behavior shared by the non-Hive Databricks engine specs.

    Datetime handling is delegated to ``HiveEngineSpec``.
    """
    _time_grain_expressions = time_grain_expressions
    @classmethod
    def convert_dttm(
        cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
    ) -> str | None:
        # Delegate to the Hive spec for SQL datetime literals.
        return HiveEngineSpec.convert_dttm(target_type, dttm, db_extra=db_extra)
    @classmethod
    def epoch_to_dttm(cls) -> str:
        # Delegate to the Hive spec for epoch-to-datetime conversion SQL.
        return HiveEngineSpec.epoch_to_dttm()
class DatabricksODBCEngineSpec(DatabricksBaseEngineSpec):
    """
    Engine spec for Databricks SQL Endpoints accessed through the ODBC driver.
    """
    engine_name = "Databricks SQL Endpoint"
    engine = "databricks"
    drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
    default_driver = "pyodbc"
class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngineSpec):
    """
    Shared logic for the Databricks engine specs that are configured through a
    dynamic form (parameters + extra) instead of a raw SQLAlchemy URI.
    """

    default_driver = ""

    # Query argument(s) added to the URI when the user enables encryption.
    encryption_parameters = {"ssl": "1"}

    # Minimum set of form parameters required for a connection; subclasses
    # extend this set with their own required fields.
    required_parameters = {"access_token", "host", "port"}

    # Maps form parameter names to the keys looked up in the error context
    # passed to ``extract_errors``.
    context_key_mapping = {
        "access_token": "password",
        "host": "hostname",
        "port": "port",
    }

    @staticmethod
    def get_extra_params(database: Database) -> dict[str, Any]:
        """
        Add a user agent to be used in the requests.
        Trim whitespace from connect_args to avoid databricks driver errors
        """
        extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
        engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
        connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
        # Only set defaults; user-provided values in "extra" win.
        connect_args.setdefault("http_headers", [("User-Agent", USER_AGENT)])
        connect_args.setdefault("_user_agent_entry", USER_AGENT)
        # trim whitespace from http_path to avoid databricks errors on connecting
        if http_path := connect_args.get("http_path"):
            connect_args["http_path"] = http_path.strip()
        return extra

    @classmethod
    def get_table_names(
        cls,
        database: Database,
        inspector: Inspector,
        schema: str | None,
    ) -> set[str]:
        """
        Return table names only, excluding names that are reported as views.
        """
        return super().get_table_names(
            database, inspector, schema
        ) - cls.get_view_names(database, inspector, schema)

    @classmethod
    def extract_errors(
        cls, ex: Exception, context: dict[str, Any] | None = None
    ) -> list[SupersetError]:
        """
        Map a raw driver exception to a list of Superset errors, using the
        spec's ``custom_errors`` regexes; falls back to a generic DB error.
        """
        raw_message = cls._extract_error_message(ex)
        context = context or {}
        # access_token isn't currently parseable from the
        # databricks error response, but adding it in here
        # for reference if their error message changes
        for key, value in cls.context_key_mapping.items():
            context[key] = context.get(value)
        for regex, (message, error_type, extra) in cls.custom_errors.items():
            match = regex.search(raw_message)
            if match:
                # Interpolate both the connection context and any named
                # groups captured by the regex into the message template.
                params = {**context, **match.groupdict()}
                extra["engine_name"] = cls.engine_name
                return [
                    SupersetError(
                        error_type=error_type,
                        message=message % params,
                        level=ErrorLevel.ERROR,
                        extra=extra,
                    )
                ]
        return [
            SupersetError(
                error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                message=cls._extract_error_message(ex),
                level=ErrorLevel.ERROR,
                extra={"engine_name": cls.engine_name},
            )
        ]

    @classmethod
    def validate_parameters(  # type: ignore
        cls,
        properties: Union[
            DatabricksNativePropertiesType,
            DatabricksPythonConnectorPropertiesType,
        ],
    ) -> list[SupersetError]:
        """
        Validate the properties coming from the dynamic form, returning a
        (possibly empty) list of errors to surface in the UI.
        """
        errors: list[SupersetError] = []

        # Fix: parse "extra" defensively and bind the derived dicts
        # unconditionally. Previously they were only bound inside
        # `if extra := json.loads(properties.get("extra"))`, so a missing
        # "extra" key raised TypeError (json.loads(None)) and an "extra" of
        # "{}" (falsy parse result) left `connect_args` unbound, raising
        # NameError at the http_path check below.
        extra = json.loads(properties.get("extra") or "{}") or {}  # type: ignore
        engine_params = extra.get("engine_params", {})
        connect_args = engine_params.get("connect_args", {})

        parameters = {
            **properties,
            **properties.get("parameters", {}),
        }
        # http_path may be supplied via connect_args instead of the form.
        if connect_args.get("http_path"):
            parameters["http_path"] = connect_args.get("http_path")

        # A parameter counts as present only when its value is truthy.
        present = {key for key in parameters if parameters.get(key, ())}
        if missing := sorted(cls.required_parameters - present):
            errors.append(
                SupersetError(
                    message=f'One or more parameters are missing: {", ".join(missing)}',
                    error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
                    level=ErrorLevel.WARNING,
                    extra={"missing": missing},
                ),
            )

        host = parameters.get("host", None)
        if not host:
            # Already reported as missing above; nothing else to check.
            return errors
        if not is_hostname_valid(host):  # type: ignore
            errors.append(
                SupersetError(
                    message="The hostname provided can't be resolved.",
                    error_type=SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
                    level=ErrorLevel.ERROR,
                    extra={"invalid": ["host"]},
                ),
            )
            # Port checks are meaningless against an unresolvable host.
            return errors

        port = parameters.get("port", None)
        if not port:
            return errors
        try:
            port = int(port)  # type: ignore
        except (ValueError, TypeError):
            errors.append(
                SupersetError(
                    message="Port must be a valid integer.",
                    error_type=SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
                    level=ErrorLevel.ERROR,
                    extra={"invalid": ["port"]},
                ),
            )
        # NOTE(review): a non-numeric port reports both the "valid integer"
        # and the range error below — preserved as existing behavior.
        if not (isinstance(port, int) and 0 <= port < 2**16):
            errors.append(
                SupersetError(
                    message=(
                        "The port must be an integer between 0 and 65535 "
                        "(inclusive)."
                    ),
                    error_type=SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
                    level=ErrorLevel.ERROR,
                    extra={"invalid": ["port"]},
                ),
            )
        elif not is_port_open(host, port):  # type: ignore
            errors.append(
                SupersetError(
                    message="The port is closed.",
                    error_type=SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
                    level=ErrorLevel.ERROR,
                    extra={"invalid": ["port"]},
                ),
            )

        return errors
class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
    """
    Engine spec for the native Databricks connector
    (``databricks+connector://`` URIs), configured through the dynamic form.
    """
    engine = "databricks"
    engine_name = "Databricks"
    drivers = {"connector": "Native all-purpose driver"}
    default_driver = "connector"
    parameters_schema = DatabricksNativeSchema()
    properties_schema = DatabricksNativePropertiesSchema()
    sqlalchemy_uri_placeholder = (
        "databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
    )
    # Extend the base mapping with keys specific to this driver.
    context_key_mapping = {
        **DatabricksDynamicBaseEngineSpec.context_key_mapping,
        "database": "database",
        "username": "username",
    }
    # "extra" is required because http_path travels in its connect_args.
    required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
        "database",
        "extra",
    }
    supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
    @classmethod
    def build_sqlalchemy_uri(  # type: ignore
        cls, parameters: DatabricksNativeParametersType, *_
    ) -> str:
        """
        Build a SQLAlchemy URI from the dynamic-form parameters, using
        "token" as the username and the access token as the password.
        """
        query = {}
        if parameters.get("encryption"):
            if not cls.encryption_parameters:
                raise Exception(  # pylint: disable=broad-exception-raised
                    "Unable to build a URL with encryption enabled"
                )
            query.update(cls.encryption_parameters)
        return str(
            URL.create(
                # rstrip drops the trailing "+" when default_driver is empty
                f"{cls.engine}+{cls.default_driver}".rstrip("+"),
                username="token",
                password=parameters.get("access_token"),
                host=parameters["host"],
                port=parameters["port"],
                database=parameters["database"],
                query=query,
            )
        )
    @classmethod
    def get_parameters_from_uri(  # type: ignore
        cls, uri: str, *_, **__
    ) -> DatabricksNativeParametersType:
        """
        Inverse of build_sqlalchemy_uri: recover the form parameters from a
        SQLAlchemy URI.
        """
        url = make_url_safe(uri)
        # Encryption is considered enabled only if every encryption query
        # argument is present in the URI.
        encryption = all(
            item in url.query.items() for item in cls.encryption_parameters.items()
        )
        return {
            "access_token": url.password,
            "host": url.host,
            "port": url.port,
            "database": url.database,
            "encryption": encryption,
        }
    @classmethod
    def parameters_json_schema(cls) -> Any:
        """
        Return configuration parameters as OpenAPI.
        """
        if not cls.properties_schema:
            return None
        spec = APISpec(
            title="Database Parameters",
            version="1.0.0",
            openapi_version="3.0.2",
            plugins=[MarshmallowPlugin()],
        )
        spec.components.schema(cls.__name__, schema=cls.properties_schema)
        return spec.to_dict()["components"]["schemas"][cls.__name__]
    @classmethod
    def get_default_catalog(
        cls,
        database: Database,
    ) -> str | None:
        """
        Return the default catalog.
        The default behavior for Databricks is confusing. When Unity Catalog is not
        enabled we have (the DB engine spec hasn't been tested with it enabled):
        > SHOW CATALOGS;
        spark_catalog
        > SELECT current_catalog();
        hive_metastore
        To handle permissions correctly we use the result of `SHOW CATALOGS` when a
        single catalog is returned.
        """
        with database.get_sqla_engine() as engine:
            catalogs = {catalog for (catalog,) in engine.execute("SHOW CATALOGS")}
            if len(catalogs) == 1:
                return catalogs.pop()
            return engine.execute("SELECT current_catalog()").scalar()
    @classmethod
    def get_prequeries(
        cls,
        database: Database,
        catalog: str | None = None,
        schema: str | None = None,
    ) -> list[str]:
        """
        Statements run before the main query to select the target catalog
        and/or schema for the session.
        """
        prequeries = []
        if catalog:
            prequeries.append(f"USE CATALOG {catalog}")
        if schema:
            prequeries.append(f"USE SCHEMA {schema}")
        return prequeries
    @classmethod
    def get_catalog_names(
        cls,
        database: Database,
        inspector: Inspector,
    ) -> set[str]:
        """
        Return all catalogs visible to the connection.
        """
        return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
    """
    Engine spec for the Databricks SQL Python connector, where http_path,
    catalog and schema travel as URI query arguments.
    """
    engine = "databricks"
    engine_name = "Databricks Python Connector"
    default_driver = "databricks-sql-python"
    drivers = {"databricks-sql-python": "Databricks SQL Python"}
    parameters_schema = DatabricksPythonConnectorSchema()
    sqlalchemy_uri_placeholder = (
        "databricks://token:{access_token}@{host}:{port}?http_path={http_path}"
        "&catalog={default_catalog}&schema={default_schema}"
    )
    # Extend the base mapping with keys specific to this driver.
    context_key_mapping = {
        **DatabricksDynamicBaseEngineSpec.context_key_mapping,
        "default_catalog": "catalog",
        "default_schema": "schema",
        "http_path_field": "http_path",
    }
    required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
        "default_catalog",
        "default_schema",
        "http_path_field",
    }
    supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
    @classmethod
    def build_sqlalchemy_uri(  # type: ignore
        cls, parameters: DatabricksPythonConnectorParametersType, *_
    ) -> str:
        """
        Build a SQLAlchemy URI from the dynamic-form parameters; http_path,
        catalog and schema are encoded as query arguments.
        """
        query = {}
        if http_path := parameters.get("http_path_field"):
            query["http_path"] = http_path
        if catalog := parameters.get("default_catalog"):
            query["catalog"] = catalog
        if schema := parameters.get("default_schema"):
            query["schema"] = schema
        if parameters.get("encryption"):
            query.update(cls.encryption_parameters)
        return str(
            URL.create(
                cls.engine,
                username="token",
                password=parameters.get("access_token"),
                host=parameters["host"],
                port=parameters["port"],
                query=query,
            )
        )
    @classmethod
    def get_parameters_from_uri(  # type: ignore
        cls, uri: str, *_: Any, **__: Any
    ) -> DatabricksPythonConnectorParametersType:
        """
        Inverse of build_sqlalchemy_uri: recover the form parameters from a
        SQLAlchemy URI.
        """
        url = make_url_safe(uri)
        # Strip the encryption arguments so they don't leak into parameters.
        query = {
            key: value
            for (key, value) in url.query.items()
            if (key, value) not in cls.encryption_parameters.items()
        }
        encryption = all(
            item in url.query.items() for item in cls.encryption_parameters.items()
        )
        # NOTE(review): raises KeyError if the URI lacks http_path/catalog/
        # schema — presumably guaranteed by form-built URIs; verify callers.
        return {
            "access_token": url.password,
            "host": url.host,
            "port": url.port,
            "http_path_field": query["http_path"],
            "default_catalog": query["catalog"],
            "default_schema": query["schema"],
            "encryption": encryption,
        }
    @classmethod
    def get_default_catalog(
        cls,
        database: Database,
    ) -> str | None:
        """
        Return the catalog stored in the database URI, if any.
        """
        return database.url_object.query.get("catalog")
    @classmethod
    def get_catalog_names(
        cls,
        database: Database,
        inspector: Inspector,
    ) -> set[str]:
        """
        Return all catalogs visible to the connection.
        """
        return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
    @classmethod
    def adjust_engine_params(
        cls,
        uri: URL,
        connect_args: dict[str, Any],
        catalog: str | None = None,
        schema: str | None = None,
    ) -> tuple[URL, dict[str, Any]]:
        """
        Point the URI at a specific catalog/schema by rewriting its query
        arguments; connect_args are returned unchanged.
        """
        if catalog:
            uri = uri.update_query_dict({"catalog": catalog})
        if schema:
            uri = uri.update_query_dict({"schema": schema})
        return uri, connect_args