superset/cli/test_db.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import sys
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable
import click
import yaml
from rich.console import Console
from sqlalchemy import (
Column,
create_engine,
DateTime,
ForeignKey,
insert,
Integer,
MetaData,
select,
String,
Table,
)
from sqlalchemy.engine import Engine
from sqlalchemy.exc import NoSuchModuleError
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.lib import (
ADVANCED_FEATURES,
BASIC_FEATURES,
DATABASE_DETAILS,
diagnose,
LIMIT_METHODS,
NICE_TO_HAVE_FEATURES,
)
# Shared metadata for the scratch tables below. Both table names use the
# ``tmp_superset_test_table_`` prefix. They are not referenced by the code in this
# section — NOTE(review): presumably used by tests elsewhere in the file; confirm.
metadata_obj = MetaData()
# Scratch "user" table. The ``email_address`` column is exposed under the Python-side
# column key ``email`` (``key="email"``).
user = Table(
    "tmp_superset_test_table_user",
    metadata_obj,
    Column("user_id", Integer, primary_key=True),
    Column("user_name", String(16), nullable=False),
    Column("email_address", String(60), key="email"),
    Column("nickname", String(50), nullable=False),
)
# Scratch per-user preferences table; ``user_id`` is a foreign key to
# ``tmp_superset_test_table_user.user_id``.
user_prefs = Table(
    "tmp_superset_test_table_user_prefs",
    metadata_obj,
    Column("pref_id", Integer, primary_key=True),
    Column(
        "user_id",
        Integer,
        ForeignKey("tmp_superset_test_table_user.user_id"),
        nullable=False,
    ),
    Column("pref_name", String(40), nullable=False),
    Column("pref_value", String(100)),
)
# Signature of an engine-specific test: it receives the rich console and the engine
# under test, and returns nothing.
TestType = Callable[[Console, Engine], None]
class TestRegistry:
    """
    Registry mapping SQLAlchemy dialect names to engine-specific test functions.
    """

    def __init__(self) -> None:
        # dialect name -> tests registered for it, in registration order
        self.tests: dict[str, list[TestType]] = defaultdict(list)

    def add(self, *dialects: str) -> Callable[[TestType], TestType]:
        """
        Return a decorator that registers the decorated test for each of ``dialects``.

        The decorated function itself is returned unchanged.
        """

        def decorator(func: TestType) -> TestType:
            for dialect in dialects:
                self.tests[dialect].append(func)
            return func

        return decorator

    def get_tests(self, dialect: str) -> list[TestType]:
        """
        Return the tests registered for ``dialect``, or an empty list if there are none.

        ``.get`` is used so that querying an unknown dialect does not insert an empty
        entry into the ``defaultdict``. A copy is returned so callers cannot mutate the
        registry through the result.
        """
        return list(self.tests.get(dialect, []))


registry = TestRegistry()
@registry.add("sqlite", "postgresql")
def test_datetime(console: Console, engine: Engine) -> None:
    """
    Create a table with a timestamp column and read value back.

    The scratch table uses the ``tmp_superset_test_table_`` prefix, like the other
    scratch tables in this module, to avoid clashing with user tables. It is always
    dropped afterwards, so the check leaves nothing behind.

    :param console: Rich console used for progress output.
    :param engine: Engine connected to the database under test.
    :raises Exception: if the table already exists, a statement fails, or the value
        read back differs from the value written.
    """
    md = MetaData()
    table = Table(
        "tmp_superset_test_table_datetime",
        md,
        Column("ts", DateTime),
    )

    console.print("Creating a table with a timestamp column...")
    # No ``checkfirst``: if the table already exists, fail instead of reusing (and
    # later dropping) a table this test did not create.
    table.create(engine)
    console.print("[green]Table created!")

    try:
        now = datetime.now()
        # ``engine.begin()`` commits on success. ``Engine.execute`` is deprecated in
        # SQLAlchemy 1.4 and removed in 2.0.
        console.print("Inserting timestamp value...")
        with engine.begin() as conn:
            conn.execute(insert(table).values(ts=now))

        console.print("Reading timestamp value...")
        with engine.connect() as conn:
            row = conn.execute(select(table)).fetchone()
    finally:
        table.drop(engine)

    # Raise explicitly instead of using ``assert``, so the check survives ``python -O``
    # and the caller can print a meaningful failure message.
    if row is None:
        raise ValueError("No row was read back from the table")
    if row[0] != now:
        raise ValueError(f"Expected {now!r}, got {row[0]!r}")
    console.print(":thumbs_up: [green]Success!")
@click.command()
@click.argument("sqlalchemy_uri")
@click.option(
    "--connect-args", "-c", "raw_engine_kwargs", help="Connect args as JSON or YAML"
)
def test_db(sqlalchemy_uri: str, raw_engine_kwargs: str | None = None) -> None:
    """
    Run a series of tests against an analytical database.
    This command tests:
    1. The Superset DB engine spec.
    2. The SQLAlchemy dialect.
    3. The database connectivity and performance.
    It's useful for people developing DB engine specs and/or SQLAlchemy dialects, and
    also to test new versions of DB API 2.0 drivers.
    TODO:
    - implement SSH tunneling
    - implement server certificates
    """
    # The docstring above doubles as the click ``--help`` text, so it is kept verbatim.
    # The stages run in order. The collected kwargs build the engine, and that engine
    # is then reused for the connectivity checks.
    out = Console()
    out.clear()

    out.print("[bold]Collecting additional connection information...")
    connect_kwargs = collect_connection_info(out, sqlalchemy_uri, raw_engine_kwargs)

    out.print("[bold]\nChecking for a DB engine spec...")
    test_db_engine_spec(out, sqlalchemy_uri)

    out.print("[bold]\nTesting the SQLAlchemy dialect...")
    db_engine = test_sqlalchemy_dialect(out, sqlalchemy_uri, connect_kwargs)

    out.print("[bold]\nTesting the database connectivity...")
    test_database_connectivity(out, db_engine)
def collect_connection_info(
    console: Console,
    sqlalchemy_uri: str,
    raw_engine_kwargs: str | None = None,
) -> dict[str, Any]:
    """
    Collect ``engine_kwargs`` if needed.

    When ``raw_engine_kwargs`` is not given, the user is asked whether they want to
    paste connection arguments. If they answer yes, the arguments are read from stdin
    until EOF.

    :param console: Rich console used for output.
    :param sqlalchemy_uri: URI of the database under test; echoed back to the user.
    :param raw_engine_kwargs: Engine kwargs as a JSON or YAML document, if already
        supplied (e.g. via ``--connect-args``).
    :returns: Keyword arguments for ``create_engine``; empty when none were provided.

    Exits the process with status 1 if the document is not valid JSON/YAML or does not
    describe a mapping.
    """
    console.print(f"[green]SQLAlchemy URI: [bold]{sqlalchemy_uri}")
    if raw_engine_kwargs is None:
        try:
            configure_engine_kwargs = input(
                "> Do you want to configure connection arguments? [y/N] "
            )
        except EOFError:
            # Non-interactive stdin: fall back to the default answer ("N").
            configure_engine_kwargs = ""
        if configure_engine_kwargs.strip().lower() == "y":
            console.print(
                "Please paste the engine_kwargs as JSON or YAML and press CTRL-D when "
                "finished"
            )
            raw_engine_kwargs = sys.stdin.read()
        else:
            raw_engine_kwargs = "{}"

    # ``safe_load`` (not ``load``): the document comes from the user.
    try:
        engine_kwargs = yaml.safe_load(raw_engine_kwargs)
    except yaml.YAMLError as ex:
        console.print(f":thumbs_down: [red]Invalid connection arguments: {ex}")
        console.print("[bold]Exiting...")
        sys.exit(1)

    # An empty document (e.g. CTRL-D with no input) parses to ``None``.
    if engine_kwargs is None:
        return {}
    if not isinstance(engine_kwargs, dict):
        console.print(
            ":thumbs_down: [red]Connection arguments must be a JSON object or a YAML "
            "mapping!"
        )
        console.print("[bold]Exiting...")
        sys.exit(1)
    return engine_kwargs
def test_db_engine_spec(
    console: Console,
    sqlalchemy_uri: str,
) -> type[BaseEngineSpec] | None:
    """
    Test the DB engine spec, if available.

    Looks for a loaded DB engine spec that supports ``sqlalchemy_uri``. If one is
    found, prints its diagnostic report (supported features and overall score).

    :param console: Rich console used for output.
    :param sqlalchemy_uri: URI of the database under test.
    :returns: The matching engine spec class, or ``None`` if no spec supports the URI.

    Exits the process with status 1 if the URI's SQLAlchemy dialect cannot be found.
    """
    spec = _find_engine_spec(console, sqlalchemy_uri)
    if spec is None:
        return None
    _print_engine_spec_report(console, spec)
    return spec


def _find_engine_spec(
    console: Console,
    sqlalchemy_uri: str,
) -> type[BaseEngineSpec] | None:
    """
    Return the first loaded DB engine spec supporting ``sqlalchemy_uri``, else ``None``.
    """
    # Parse the URI once instead of once per candidate spec.
    url = make_url_safe(sqlalchemy_uri)
    for candidate in load_engine_specs():
        try:
            supported = candidate.supports_url(url)
        except NoSuchModuleError:
            console.print("[red]No SQLAlchemy dialect found for the URI!")
            console.print("[bold]Exiting...")
            sys.exit(1)
        if not supported:
            continue

        if candidate.__module__.startswith("superset.db_engine_specs"):
            console.print(
                f":thumbs_up: [green]Found DB engine spec: [bold]{candidate.engine_name}"
            )
        else:
            console.print(
                ":warning: [yellow]Found 3rd party DB engine spec: "
                f"[bold]{candidate.engine_name} ({candidate.__module__})"
            )
        return candidate

    # Return None explicitly when nothing matched. Reusing the loop variable would
    # yield the last spec tried, which does not support the URI.
    console.print(
        ":thumbs_down: [red]No DB engine spec found for the SQLAlchemy URI. The "
        "database can still be used with Superset, but some functionality may be "
        "limited."
    )
    return None


def _print_engine_spec_report(
    console: Console,
    spec: type[BaseEngineSpec],
) -> None:
    """
    Print the ``diagnose`` report for ``spec``: database details, features and score.
    """
    info = diagnose(spec)

    console.print("About the database:")
    console.print(" - Method used to apply LIMIT to queries:", info["limit_method"])
    for k, v in LIMIT_METHODS.items():
        console.print(f" - {k}: {v}")
    for key, feature in DATABASE_DETAILS.items():
        console.print(f" - {feature}:", info[key])

    console.print("[bold]Checking for basic features...")
    console.print("Supported time grains:")
    for k, v in info["time_grains"].items():
        score = " (+1)" if v else ""
        console.print(f" - {k}: {v}{score}")
    for k, v in BASIC_FEATURES.items():
        score = " (+10)" if info[k] else ""
        console.print(f"{v}: {info[k]}{score}")

    console.print("[bold]Checking for nice-to-have features...")
    for k, v in NICE_TO_HAVE_FEATURES.items():
        score = " (+10)" if info[k] else ""
        console.print(f"{v}: {info[k]}{score}")

    console.print("[bold]Checking for advanced features...")
    for k, v in ADVANCED_FEATURES.items():
        score = " (+10)" if info[k] else ""
        console.print(f"{v}: {info[k]}{score}")

    # pylint: disable=consider-using-f-string
    console.print("[bold]Overall score: {score}/{max_score}".format(**info))
def test_sqlalchemy_dialect(
    console: Console,
    sqlalchemy_uri: str,
    engine_kwargs: dict[str, Any],
) -> Engine:
    """
    Test the SQLAlchemy dialect, making sure it supports everything Superset needs.

    Builds an engine from ``sqlalchemy_uri`` and ``engine_kwargs``, then reports which
    inspector methods and dialect attributes are available.

    :param console: Rich console used for output.
    :param sqlalchemy_uri: URI of the database under test.
    :param engine_kwargs: Extra keyword arguments passed to ``create_engine``.
    :returns: The created engine, for use by the connectivity tests.

    Exits the process with status 1 if the dialect or its DB API driver cannot be
    loaded.
    """
    try:
        engine = create_engine(sqlalchemy_uri, **engine_kwargs)
    except (NoSuchModuleError, ImportError) as ex:
        # A missing dialect or driver package is a user-environment problem: report it
        # like the other checks do instead of dumping a traceback.
        console.print(f":thumbs_down: [red]Unable to load the dialect or driver: {ex}")
        console.print("[bold]Exiting...")
        sys.exit(1)
    dialect = engine.dialect

    console.print("[bold]Checking functions used by the inspector...")
    keys = [
        "get_schema_names",
        "get_table_names",
        "get_view_names",
        "get_indexes",
        "get_table_comment",
        "get_columns",
        "get_unique_constraints",
        "get_check_constraints",
        "get_pk_constraint",
        "get_foreign_keys",
    ]
    # NOTE(review): SQLAlchemy's base dialect classes may define some of these as
    # stubs. In that case ``hasattr`` is True even when the dialect does not implement
    # them; confirm for the targeted SQLAlchemy version.
    for key in keys:
        console.print(f" - {key}:", hasattr(dialect, key))

    console.print("[bold]Checking dialect attributes...")
    # ``dbapi`` can exist but be ``None``; don't assume it has a ``__name__``.
    dbapi = getattr(dialect, "dbapi", None)
    if dbapi is not None:
        console.print(f" - dbapi: [bold]{getattr(dbapi, '__name__', dbapi)}")
    else:
        console.print(" - dbapi:", None)

    attrs = [
        "name",
        "driver",
        "supports_multivalues_insert",
    ]
    for attr in attrs:
        console.print(f" - {attr}:", getattr(dialect, attr, None))

    console.print("Supports do_ping:", hasattr(dialect, "do_ping"))
    console.print(
        "Can quote identifiers:",
        hasattr(dialect, "identifier_preparer")
        and hasattr(dialect.identifier_preparer, "quote"),
    )
    console.print(
        "Doesn't require name normalization:",
        not dialect.requires_name_normalize,
    )
    if dialect.requires_name_normalize:
        console.print(
            " - Implements denormalize_name:", hasattr(dialect, "denormalize_name")
        )

    return engine
def test_database_connectivity(console: Console, engine: Engine) -> None:
    """
    Tests the DB API 2.0 driver.

    Opens a raw DB API connection, pings it, and runs ``SELECT 1`` through a cursor. It
    then runs any engine-specific tests registered for the engine's dialect.

    :param console: Rich console used for output.
    :param engine: Engine for the database under test.

    Exits the process with status 1 if connecting, running the probe query, or any
    engine-specific test fails.
    """
    with console.status("[bold green]Connecting to database..."):
        try:
            conn = engine.raw_connection()
            engine.dialect.do_ping(conn)
            console.print(":thumbs_up: [green]Connected successfully!")
        except Exception as ex:  # pylint: disable=broad-except
            console.print(f":thumbs_down: [red]Failed to connect: {ex}")
            console.print("[bold]Exiting...")
            sys.exit(1)

    console.print("[bold]Checking that we can run queries...")
    console.print("sql> SELECT 1;")
    try:
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT 1")
            row = cursor.fetchone()
        finally:
            cursor.close()
    except Exception as ex:  # pylint: disable=broad-except
        console.print(f":thumbs_down: [red]Failed to run query: {ex}")
        console.print("[bold]Exiting...")
        sys.exit(1)
    finally:
        # Release the raw connection. The engine-specific tests below use ``engine``
        # directly rather than this connection.
        conn.close()

    # ``fetchone`` returns None when no row is produced; report it instead of crashing.
    result = row[0] if row is not None else None
    color = "green" if result == 1 else "red"
    console.print(f"[{color}]> {result}")

    # run engine-specific tests
    if tests := registry.get_tests(engine.dialect.name):
        console.print("[bold]Running engine-specific tests...")
        for test in tests:
            # Headline: first docstring line, falling back to the function name. A
            # blank docstring would otherwise make ``splitlines()[0]`` raise.
            summary = (test.__doc__ or "").strip() or test.__name__
            console.print(f"[bold]{summary.splitlines()[0]}...")
            try:
                test(console, engine)
            except Exception as ex:  # pylint: disable=broad-except
                console.print(f"[red]Test failed: {ex}")
                console.print("[bold]Exiting...")
                sys.exit(1)