superset/migrations/shared/migrate_viz/processors.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from superset.utils.core import as_list
from .base import MigrateViz
class MigrateTreeMap(MigrateViz):
source_viz_type = "treemap"
target_viz_type = "treemap_v2"
remove_keys = {"metrics"}
rename_keys = {"order_desc": "sort_by_metric"}
def _pre_action(self) -> None:
if (
"metrics" in self.data
and isinstance(self.data["metrics"], list)
and len(self.data["metrics"]) > 0
):
self.data["metric"] = self.data["metrics"][0]
class MigratePivotTable(MigrateViz):
source_viz_type = "pivot_table"
target_viz_type = "pivot_table_v2"
remove_keys = {"pivot_margins"}
rename_keys = {
"columns": "groupbyColumns",
"combine_metric": "combineMetric",
"groupby": "groupbyRows",
"number_format": "valueFormat",
"pandas_aggfunc": "aggregateFunction",
"row_limit": "series_limit",
"timeseries_limit_metric": "series_limit_metric",
"transpose_pivot": "transposePivot",
}
aggregation_mapping = {
"sum": "Sum",
"mean": "Average",
"median": "Median",
"min": "Minimum",
"max": "Maximum",
"std": "Sample Standard Deviation",
"var": "Sample Variance",
}
def _pre_action(self) -> None:
if pivot_margins := self.data.get("pivot_margins"):
self.data["colTotals"] = pivot_margins
self.data["colSubTotals"] = pivot_margins
if pandas_aggfunc := self.data.get("pandas_aggfunc"):
self.data["pandas_aggfunc"] = self.aggregation_mapping[pandas_aggfunc]
self.data["rowOrder"] = "value_z_to_a"
class MigrateDualLine(MigrateViz):
has_x_axis_control = True
source_viz_type = "dual_line"
target_viz_type = "mixed_timeseries"
rename_keys = {
"x_axis_format": "x_axis_time_format",
"y_axis_2_format": "y_axis_format_secondary",
"y_axis_2_bounds": "y_axis_bounds_secondary",
}
remove_keys = {"metric", "metric_2"}
def _pre_action(self) -> None:
self.data["yAxisIndex"] = 0
self.data["yAxisIndexB"] = 1
self.data["adhoc_filters_b"] = self.data.get("adhoc_filters")
self.data["truncateYAxis"] = True
self.data["metrics"] = [self.data.get("metric")]
self.data["metrics_b"] = [self.data.get("metric_2")]
def _migrate_temporal_filter(self, rv_data: dict[str, Any]) -> None:
super()._migrate_temporal_filter(rv_data)
rv_data["adhoc_filters_b"] = rv_data.get("adhoc_filters") or []
class MigrateSunburst(MigrateViz):
source_viz_type = "sunburst"
target_viz_type = "sunburst_v2"
rename_keys = {"groupby": "columns"}
class TimeseriesChart(MigrateViz):
has_x_axis_control = True
rename_keys = {
"bottom_margin": "x_axis_title_margin",
"left_margin": "y_axis_title_margin",
"show_controls": "show_extra_controls",
"x_axis_label": "x_axis_title",
"x_axis_format": "x_axis_time_format",
"x_axis_showminmax": "truncateXAxis",
"x_ticks_layout": "xAxisLabelRotation",
"y_axis_label": "y_axis_title",
"y_axis_showminmax": "truncateYAxis",
"y_log_scale": "logAxis",
}
remove_keys = {
"contribution",
"line_interpolation",
"reduce_x_ticks",
"show_brush",
"show_markers",
}
def _pre_action(self) -> None:
self.data["contributionMode"] = "row" if self.data.get("contribution") else None
self.data["zoomable"] = self.data.get("show_brush") == "yes"
self.data["markerEnabled"] = self.data.get("show_markers") or False
self.data["y_axis_showminmax"] = True
bottom_margin = self.data.get("bottom_margin")
if self.data.get("x_axis_label") and (
not bottom_margin or bottom_margin == "auto"
):
self.data["bottom_margin"] = 30
left_margin = self.data.get("left_margin")
if self.data.get("y_axis_label") and (not left_margin or left_margin == "auto"):
self.data["left_margin"] = 30
if (rolling_type := self.data.get("rolling_type")) and rolling_type != "None":
self.data["rolling_type"] = rolling_type
if time_compare := self.data.get("time_compare"):
self.data["time_compare"] = [
value + " ago" for value in as_list(time_compare) if value
]
comparison_type = self.data.get("comparison_type") or "values"
self.data["comparison_type"] = (
"difference" if comparison_type == "absolute" else comparison_type
)
if x_ticks_layout := self.data.get("x_ticks_layout"):
self.data["x_ticks_layout"] = 45 if x_ticks_layout == "45°" else 0
class MigrateLineChart(TimeseriesChart):
source_viz_type = "line"
target_viz_type = "echarts_timeseries_line"
def _pre_action(self) -> None:
super()._pre_action()
line_interpolation = self.data.get("line_interpolation")
if line_interpolation == "cardinal":
self.target_viz_type = "echarts_timeseries_smooth"
elif line_interpolation == "step-before":
self.target_viz_type = "echarts_timeseries_step"
self.data["seriesType"] = "start"
elif line_interpolation == "step-after":
self.target_viz_type = "echarts_timeseries_step"
self.data["seriesType"] = "end"
class MigrateAreaChart(TimeseriesChart):
source_viz_type = "area"
target_viz_type = "echarts_area"
stacked_map = {
"expand": "Expand",
"stack": "Stack",
"stream": "Stream",
}
def _pre_action(self) -> None:
super()._pre_action()
self.remove_keys.add("stacked_style")
self.data["stack"] = self.stacked_map.get(
self.data.get("stacked_style") or "stack"
)
self.data["opacity"] = 0.7
class MigrateBarChart(TimeseriesChart):
source_viz_type = "bar"
target_viz_type = "echarts_timeseries_bar"
def _pre_action(self) -> None:
super()._pre_action()
self.rename_keys["show_bar_value"] = "show_value"
self.remove_keys.add("bar_stacked")
self.data["stack"] = "Stack" if self.data.get("bar_stacked") else None
class MigrateDistBarChart(TimeseriesChart):
source_viz_type = "dist_bar"
target_viz_type = "echarts_timeseries_bar"
has_x_axis_control = False
def _pre_action(self) -> None:
super()._pre_action()
groupby = self.data.get("groupby") or []
columns = self.data.get("columns") or []
if len(groupby) > 0:
# x-axis supports only one value
self.data["x_axis"] = groupby[0]
self.data["groupby"] = []
if len(groupby) > 1:
# rest of groupby will go into dimensions
self.data["groupby"] += groupby[1:]
if len(columns) > 0:
self.data["groupby"] += columns
self.rename_keys["show_bar_value"] = "show_value"
self.remove_keys.add("columns")
self.remove_keys.add("bar_stacked")
self.data["stack"] = "Stack" if self.data.get("bar_stacked") else None
class MigrateBubbleChart(MigrateViz):
source_viz_type = "bubble"
target_viz_type = "bubble_v2"
rename_keys = {
"bottom_margin": "x_axis_title_margin",
"left_margin": "y_axis_title_margin",
"limit": "row_limit",
"x_axis_format": "xAxisFormat",
"x_log_scale": "logXAxis",
"x_ticks_layout": "xAxisLabelRotation",
"y_axis_showminmax": "truncateYAxis",
"y_log_scale": "logYAxis",
}
remove_keys = {"x_axis_showminmax"}
def _pre_action(self) -> None:
bottom_margin = self.data.get("bottom_margin")
if self.data.get("x_axis_label") and (
not bottom_margin or bottom_margin == "auto"
):
self.data["bottom_margin"] = 30
if x_ticks_layout := self.data.get("x_ticks_layout"):
self.data["x_ticks_layout"] = 45 if x_ticks_layout == "45°" else 0
# Truncate y-axis by default to preserve layout
self.data["y_axis_showminmax"] = True
class MigrateHeatmapChart(MigrateViz):
source_viz_type = "heatmap"
target_viz_type = "heatmap_v2"
rename_keys = {
"all_columns_x": "x_axis",
"all_columns_y": "groupby",
"y_axis_bounds": "value_bounds",
"show_perc": "show_percentage",
}
remove_keys = {"sort_by_metric", "canvas_image_rendering"}
def _pre_action(self) -> None:
self.data["legend_type"] = "continuous"
class MigrateHistogramChart(MigrateViz):
source_viz_type = "histogram"
target_viz_type = "histogram_v2"
rename_keys = {
"x_axis_label": "x_axis_title",
"y_axis_label": "y_axis_title",
"normalized": "normalize",
}
remove_keys = {"all_columns_x", "link_length", "queryFields"}
def _pre_action(self) -> None:
all_columns_x = self.data.get("all_columns_x")
if all_columns_x and len(all_columns_x) > 0:
self.data["column"] = all_columns_x[0]
link_length = self.data.get("link_length")
self.data["bins"] = int(link_length) if link_length else 5
groupby = self.data.get("groupby")
if not groupby:
self.data["groupby"] = []
class MigrateSankey(MigrateViz):
source_viz_type = "sankey"
target_viz_type = "sankey_v2"
remove_keys = {"groupby"}
def _pre_action(self) -> None:
groupby = self.data.get("groupby")
if groupby and len(groupby) > 1:
self.data["source"] = groupby[0]
self.data["target"] = groupby[1]