superset/common/query_context_factory.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Any, TYPE_CHECKING
from superset import app
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_object import QueryObject
from superset.common.query_object_factory import QueryObjectFactory
from superset.daos.chart import ChartDAO
from superset.daos.datasource import DatasourceDAO
from superset.models.slice import Slice
from superset.utils.core import DatasourceDict, DatasourceType, is_adhoc_column
if TYPE_CHECKING:
from superset.connectors.sqla.models import BaseDatasource
# Bound once, at import time, to ``app.config``; handed to QueryObjectFactory by
# create_query_object_factory() below.
config = app.config
def create_query_object_factory() -> QueryObjectFactory:
    """Return a QueryObjectFactory wired to the module ``config`` and a fresh DatasourceDAO."""
    datasource_dao = DatasourceDAO()
    return QueryObjectFactory(config, datasource_dao)
class QueryContextFactory:  # pylint: disable=too-few-public-methods
    """Build QueryContext instances from a datasource reference and raw query dicts."""

    _query_object_factory: QueryObjectFactory

    def __init__(self) -> None:
        self._query_object_factory = create_query_object_factory()

    def create(  # pylint: disable=too-many-arguments
        self,
        *,
        datasource: DatasourceDict,
        queries: list[dict[str, Any]],
        form_data: dict[str, Any] | None = None,
        result_type: ChartDataResultType | None = None,
        result_format: ChartDataResultFormat | None = None,
        force: bool = False,
        custom_cache_timeout: int | None = None,
    ) -> QueryContext:
        """
        Create a QueryContext.

        :param datasource: datasource reference; its ``type`` and ``id`` keys are
            used to load the model. When falsy, no model is loaded and ``None``
            is used as the context's datasource.
        :param queries: raw query dicts, each turned into a QueryObject
        :param form_data: optional chart form data (``slice_id`` and ``x_axis``
            are read from it)
        :param result_type: defaults to ``ChartDataResultType.FULL``
        :param result_format: defaults to ``ChartDataResultFormat.JSON``
        :param force: forwarded to QueryContext
        :param custom_cache_timeout: forwarded to QueryContext
        :return: the assembled QueryContext
        """
        datasource_model_instance = (
            self._convert_to_model(datasource) if datasource else None
        )

        slice_ = None
        if form_data and (slice_id := form_data.get("slice_id")) is not None:
            slice_ = self._get_slice(slice_id)

        result_type = result_type or ChartDataResultType.FULL
        result_format = result_format or ChartDataResultFormat.JSON

        queries_ = [
            self._process_query_object(
                datasource_model_instance,
                form_data,
                self._query_object_factory.create(
                    result_type, datasource=datasource, **query_obj
                ),
            )
            for query_obj in queries
        ]

        # The unprocessed request inputs, passed through as ``cache_values``
        # (presumably used by QueryContext for caching -- confirm there).
        cache_values = {
            "datasource": datasource,
            "queries": queries,
            "result_type": result_type,
            "result_format": result_format,
        }

        return QueryContext(
            datasource=datasource_model_instance,
            queries=queries_,
            slice_=slice_,
            form_data=form_data,
            result_type=result_type,
            result_format=result_format,
            force=force,
            custom_cache_timeout=custom_cache_timeout,
            cache_values=cache_values,
        )

    def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
        """Load the datasource model identified by ``datasource["type"]``/``["id"]``."""
        return DatasourceDAO.get_datasource(
            datasource_type=DatasourceType(datasource["type"]),
            datasource_id=int(datasource["id"]),
        )

    def _get_slice(self, slice_id: Any) -> Slice | None:
        """Look up a chart by id via ChartDAO; may return None."""
        return ChartDAO.find_by_id(slice_id)

    def _process_query_object(
        self,
        datasource: BaseDatasource | None,
        form_data: dict[str, Any] | None,
        query_object: QueryObject,
    ) -> QueryObject:
        """Apply granularity and time-range adjustments to ``query_object`` in place.

        :return: the same ``query_object`` instance
        """
        self._apply_granularity(query_object, form_data, datasource)
        self._apply_filters(query_object)
        return query_object

    def _apply_granularity(
        self,
        query_object: QueryObject,
        form_data: dict[str, Any] | None,
        datasource: BaseDatasource | None,
    ) -> None:
        """
        Rewrite ``query_object`` around its granularity column, in place.

        Does nothing unless ``query_object.granularity`` is set. Otherwise:

        * if the chart's x-axis is a temporal (``is_dttm``) column of
          ``datasource``, the matching entry of ``query_object.columns`` is
          replaced by the granularity and pivot post-processing is re-indexed
          on it;
        * one filter column is then selected (the temporal x-axis, or else the
          granularity / first TEMPORAL_RANGE filter column) and every filter on
          that column is removed. A filter based on the granularity is added
          later in the code, so this effectively replaces the previous default
          temporal filter.

        A missing ``datasource`` is treated as having no temporal columns.
        """
        granularity = query_object.granularity
        if not granularity:
            return

        x_axis = form_data and form_data.get("x_axis")
        if is_adhoc_column(x_axis):  # type: ignore
            x_axis = x_axis.get("sqlExpression")

        filter_to_remove = None
        if x_axis and x_axis in self._get_temporal_column_names(datasource):
            filter_to_remove = x_axis
            self._replace_x_axis_with_granularity(query_object, x_axis, granularity)

        # If no temporal x-axis, then use the default temporal filter
        if not filter_to_remove:
            filter_to_remove = self._get_default_temporal_filter_column(
                query_object, granularity
            )

        # Compare normalized keys on BOTH sides: adhoc columns are reduced to
        # their sqlExpression. Normalizing only the target would mean an adhoc
        # filter column (a dict) never equals it and is never removed.
        target = self._normalize_column(filter_to_remove)
        if target:
            query_object.filter = [
                query_filter
                for query_filter in query_object.filter
                if self._normalize_column(query_filter["col"]) != target
            ]

    @staticmethod
    def _get_temporal_column_names(datasource: BaseDatasource | None) -> set[Any]:
        """Return the names of ``datasource`` columns flagged ``is_dttm``.

        Columns may be plain dicts or model objects, so both access styles are
        supported. Returns an empty set when there is no datasource.
        """
        if datasource is None:
            return set()
        return {
            column["column_name"] if isinstance(column, dict) else column.column_name
            for column in datasource.columns
            if (column["is_dttm"] if isinstance(column, dict) else column.is_dttm)
        }

    @staticmethod
    def _replace_x_axis_with_granularity(
        query_object: QueryObject, x_axis: Any, granularity: Any
    ) -> None:
        """Point the x-axis column and any pivot index of ``query_object`` at ``granularity``."""
        x_axis_column = next(
            (
                column
                for column in query_object.columns
                if column == x_axis
                or (
                    isinstance(column, dict)
                    and column.get("sqlExpression") == x_axis
                )
            ),
            None,
        )
        if not x_axis_column:
            return

        if isinstance(x_axis_column, dict):
            # Adhoc column: mutated in place, so query_object.columns reflects it.
            x_axis_column["sqlExpression"] = granularity
            x_axis_column["label"] = granularity
        else:
            query_object.columns = [
                granularity if column == x_axis_column else column
                for column in query_object.columns
            ]

        for post_processing in query_object.post_processing:
            if post_processing.get("operation") == "pivot":
                post_processing["options"]["index"] = [granularity]

    @staticmethod
    def _get_default_temporal_filter_column(
        query_object: QueryObject, granularity: Any
    ) -> Any:
        """Choose the TEMPORAL_RANGE filter column to replace, or None if none exist.

        Prefers ``granularity`` when it is already a TEMPORAL_RANGE filter
        column; otherwise uses the first such column.
        """
        temporal_filter_columns = [
            query_filter["col"]
            for query_filter in query_object.filter
            if query_filter["op"] == "TEMPORAL_RANGE"
        ]
        if not temporal_filter_columns:
            return None
        if granularity in temporal_filter_columns:
            return granularity
        return temporal_filter_columns[0]

    @staticmethod
    def _normalize_column(column: Any) -> Any:
        """Reduce an adhoc column to its ``sqlExpression``; return others unchanged."""
        if is_adhoc_column(column):  # type: ignore
            return column.get("sqlExpression")
        return column

    def _apply_filters(self, query_object: QueryObject) -> None:
        """Set every TEMPORAL_RANGE filter's value to ``query_object.time_range``, if set."""
        if not query_object.time_range:
            return
        for filter_object in query_object.filter:
            if filter_object["op"] == "TEMPORAL_RANGE":
                filter_object["val"] = query_object.time_range