superset/sql_validators/presto_db.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
import time
from contextlib import closing
from typing import Any
from superset import app
from superset.models.core import Database
from superset.sql_parse import ParsedQuery
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
from superset.utils.core import QuerySource
MAX_ERROR_ROWS = 10
config = app.config
logger = logging.getLogger(__name__)
class PrestoSQLValidationError(Exception):
"""Error in the process of asking Presto to validate SQL querytext"""
class PrestoDBSQLValidator(BaseSQLValidator):
"""Validate SQL queries using Presto's built-in EXPLAIN subtype"""
name = "PrestoDBSQLValidator"
@classmethod
def validate_statement(
cls,
statement: str,
database: Database,
cursor: Any,
) -> SQLValidationAnnotation | None:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
sql = parsed_query.stripped()
# Hook to allow environment-specific mutation (usually comments) to the SQL
sql = database.mutate_sql_based_on_config(sql)
# Transform the final statement to an explain call before sending it on
# to presto to validate
sql = f"EXPLAIN (TYPE VALIDATE) {sql}"
# Invoke the query against presto. NB this deliberately doesn't use the
# engine spec's handle_cursor implementation since we don't record
# these EXPLAIN queries done in validation as proper Query objects
# in the superset ORM.
# pylint: disable=import-outside-toplevel
from pyhive.exc import DatabaseError
try:
db_engine_spec.execute(cursor, sql, database)
polled = cursor.poll()
while polled:
logger.info("polling presto for validation progress")
stats = polled.get("stats", {})
if stats:
state = stats.get("state")
if state == "FINISHED":
break
time.sleep(0.2)
polled = cursor.poll()
db_engine_spec.fetch_data(cursor, MAX_ERROR_ROWS)
return None
except DatabaseError as db_error:
# The pyhive presto client yields EXPLAIN (TYPE VALIDATE) responses
# as though they were normal queries. In other words, it doesn't
# know that errors here are not exceptional. To map this back to
# ordinary control flow, we have to trap the category of exception
# raised by the underlying client, match the exception arguments
# pyhive provides against the shape of dictionary for a presto query
# invalid error, and restructure that error as an annotation we can
# return up.
# If the first element in the DatabaseError is not a dictionary, but
# is a string, return that message.
if db_error.args and isinstance(db_error.args[0], str):
raise PrestoSQLValidationError(db_error.args[0]) from db_error
# Confirm the first element in the DatabaseError constructor is a
# dictionary with error information. This is currently provided by
# the pyhive client, but may break if their interface changes when
# we update at some point in the future.
if not db_error.args or not isinstance(db_error.args[0], dict):
raise PrestoSQLValidationError(
"The pyhive presto client returned an unhandled database error."
) from db_error
error_args: dict[str, Any] = db_error.args[0]
# Confirm the two fields we need to be able to present an annotation
# are present in the error response -- a message, and a location.
if "message" not in error_args:
raise PrestoSQLValidationError(
"The pyhive presto client did not report an error message"
) from db_error
if "errorLocation" not in error_args:
# Pylint is confused about the type of error_args, despite the hints
# and checks above.
message = error_args["message"] + "\n(Error location unknown)"
# If we have a message but no error location, return the message and
# set the location as the beginning.
return SQLValidationAnnotation(
message=message, line_number=1, start_column=1, end_column=1
)
message = error_args["message"]
err_loc = error_args["errorLocation"]
line_number = err_loc.get("lineNumber", None)
start_column = err_loc.get("columnNumber", None)
end_column = err_loc.get("columnNumber", None)
return SQLValidationAnnotation(
message=message,
line_number=line_number,
start_column=start_column,
end_column=end_column,
)
except Exception as ex:
logger.exception("Unexpected error running validation query: %s", str(ex))
raise
@classmethod
def validate(
cls,
sql: str,
catalog: str | None,
schema: str | None,
database: Database,
) -> list[SQLValidationAnnotation]:
"""
Presto supports query-validation queries by running them with a
prepended explain.
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
VALIDATE) SELECT 1 FROM default.mytable.
"""
parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
statements = parsed_query.get_statements()
logger.info("Validating %i statement(s)", len(statements))
# todo(hughhh): update this to use new database.get_raw_connection()
# this function keeps stalling CI
with database.get_sqla_engine(
catalog=catalog,
schema=schema,
source=QuerySource.SQL_LAB,
) as engine:
# Sharing a single connection and cursor across the
# execution of all statements (if many)
annotations: list[SQLValidationAnnotation] = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in parsed_query.get_statements():
annotation = cls.validate_statement(statement, database, cursor)
if annotation:
annotations.append(annotation)
logger.debug("Validation found %i error(s)", len(annotations))
return annotations