dbt_airflow_factory/tasks_builder/builder.py
"""Class parsing ``manifest.json`` into Airflow tasks."""
import json
import logging
from typing import Any, ContextManager, Dict, Tuple
from airflow.models.baseoperator import BaseOperator
from airflow.operators.dummy import DummyOperator
from airflow.sensors.external_task_sensor import ExternalTaskSensor
from dbt_airflow_factory.constants import IS_FIRST_AIRFLOW_VERSION
from dbt_airflow_factory.tasks_builder.parameters import TasksBuildingParameters
if not IS_FIRST_AIRFLOW_VERSION:
from airflow.utils.task_group import TaskGroup
from dbt_graph_builder.builder import GraphConfiguration, create_tasks_graph
from dbt_graph_builder.gateway import GatewayConfiguration
from dbt_graph_builder.graph import DbtManifestGraph
from dbt_graph_builder.node_type import NodeType
from dbt_airflow_factory.operator import DbtRunOperatorBuilder, EphemeralOperator
from dbt_airflow_factory.tasks import ModelExecutionTask, ModelExecutionTasks
class DbtAirflowTasksBuilder:
    """
    Parses ``manifest.json`` into Airflow tasks.

    :param airflow_config: Parameters controlling how tasks are built
        (task groups, DAG dependencies, ephemeral models, multiple-deps tests).
    :type airflow_config: TasksBuildingParameters
    :param operator_builder: Builder creating the Airflow operator for each dbt command.
    :type operator_builder: DbtRunOperatorBuilder
    :param gateway_config: Gateway configuration forwarded to the graph builder.
    :type gateway_config: GatewayConfiguration
    """

    def __init__(
        self,
        airflow_config: TasksBuildingParameters,
        operator_builder: DbtRunOperatorBuilder,
        gateway_config: GatewayConfiguration,
    ):
        self.operator_builder = operator_builder
        self.airflow_config = airflow_config
        self.gateway_config = gateway_config

    def parse_manifest_into_tasks(self, manifest_path: str) -> ModelExecutionTasks:
        """
        Parse ``manifest.json`` into tasks.

        :param manifest_path: Path to ``manifest.json``.
        :type manifest_path: str
        :return: Dictionary of tasks created from ``manifest.json`` parsing.
        :rtype: ModelExecutionTasks
        """
        return self._make_dbt_tasks(manifest_path)

    def create_seed_task(self) -> BaseOperator:
        """
        Create ``dbt_seed`` task.

        :return: Operator for ``dbt_seed`` task.
        :rtype: BaseOperator
        """
        return self.operator_builder.create("dbt_seed", "seed")

    @staticmethod
    def _load_dbt_manifest(manifest_path: str) -> dict:
        """
        Read and deserialize the manifest file.

        :param manifest_path: Path to ``manifest.json``.
        :type manifest_path: str
        :return: Parsed JSON content of the manifest.
        :rtype: dict
        """
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(manifest_path, "r", encoding="utf-8") as f:
            manifest_content = json.load(f)
        # Lazy %-formatting: the (potentially huge) manifest is only rendered
        # to a string when DEBUG logging is actually enabled.
        logging.debug("Manifest content: %s", manifest_content)
        return manifest_content

    def _make_dbt_test_task(self, model_name: str, is_in_task_group: bool) -> BaseOperator:
        """
        Create the operator running ``test`` for a single model.

        :param model_name: Name of the model to select.
        :param is_in_task_group: Whether the task is created inside a task group.
        :return: Operator produced by the operator builder.
        """
        command = "test"
        return self.operator_builder.create(
            self._build_task_name(model_name, command, is_in_task_group),
            command,
            model_name,
            additional_dbt_args=["--indirect-selection=cautious"],
        )

    def _make_dbt_multiple_deps_test_task(
        self, test_names: str, dependency_tuple_str: str
    ) -> BaseOperator:
        """
        Create the operator running tests that depend on more than one model.

        :param test_names: Selector passed to the ``test`` command.
        :param dependency_tuple_str: Graph node name, used as the task name.
        :return: Operator produced by the operator builder.
        """
        command = "test"
        return self.operator_builder.create(dependency_tuple_str, command, test_names)

    def _make_dbt_run_task(self, model_name: str, is_in_task_group: bool) -> BaseOperator:
        """
        Create the operator running ``run`` for a single model.

        :param model_name: Name of the model to select.
        :param is_in_task_group: Whether the task is created inside a task group.
        :return: Operator produced by the operator builder.
        """
        command = "run"
        return self.operator_builder.create(
            self._build_task_name(model_name, command, is_in_task_group),
            command,
            model_name,
        )

    @staticmethod
    def _build_task_name(model_name: str, command: str, is_in_task_group: bool) -> str:
        """
        Build the task name for a model command.

        Inside a task group the bare command is used; otherwise the name is
        ``<model_name>_<command>``.

        :param model_name: Name of the model.
        :param command: Command executed by the task (e.g. ``run``).
        :param is_in_task_group: Whether the task is created inside a task group.
        :return: Name of the task.
        """
        return command if is_in_task_group else f"{model_name}_{command}"

    @staticmethod
    def _create_task_group_for_model(
        model_name: str, use_task_group: bool
    ) -> Tuple[Any, ContextManager]:
        """
        Create an optional task group for a model plus a context manager to enter.

        :param model_name: Name of the model, used as the group id.
        :param use_task_group: Whether task groups were requested.
        :return: ``(task_group, context)``; ``task_group`` is ``None`` and
            ``context`` is a no-op context when no group is created.
        """
        import contextlib

        # ``TaskGroup`` is only imported when IS_FIRST_AIRFLOW_VERSION is false,
        # so the version flag has to be checked before the name is touched.
        task_group = (
            None
            if (IS_FIRST_AIRFLOW_VERSION or not use_task_group)
            else TaskGroup(group_id=model_name)
        )
        # Identity check (not truthiness) so this stays in step with the
        # ``task_group is not None`` test used by ``_create_task_for_model``.
        task_group_ctx = contextlib.nullcontext() if task_group is None else task_group
        return task_group, task_group_ctx

    def _create_task_for_model(
        self,
        model_name: str,
        use_task_group: bool,
    ) -> ModelExecutionTask:
        """
        Create the ``run`` and ``test`` tasks for a model and chain them.

        :param model_name: Name of the model.
        :param use_task_group: Whether the two tasks should be wrapped in a task group.
        :return: Execution task holding the run task, test task and task group (if any).
        """
        (task_group, task_group_ctx) = self._create_task_group_for_model(model_name, use_task_group)
        is_in_task_group = task_group is not None
        with task_group_ctx:
            run_task = self._make_dbt_run_task(model_name, is_in_task_group)
            test_task = self._make_dbt_test_task(model_name, is_in_task_group)
            # The model is tested only after it has been run.
            # noinspection PyStatementEffect
            run_task >> test_task
        return ModelExecutionTask(run_task, test_task, task_group)

    def _create_task_from_graph_node(
        self, node_name: str, node: Dict[str, Any]
    ) -> ModelExecutionTask:
        """
        Create the execution task matching the type of a graph node.

        :param node_name: Name of the node in the graph.
        :param node: Node attributes; ``node_type`` and ``select`` are read here.
        :return: Execution task for the node.
        """
        node_type = node["node_type"]
        if node_type == NodeType.MULTIPLE_DEPS_TEST:
            return ModelExecutionTask(
                self._make_dbt_multiple_deps_test_task(node["select"], node_name), None
            )
        if node_type == NodeType.SOURCE_SENSOR:
            return self._create_dag_sensor(node)
        if node_type == NodeType.MOCK_GATEWAY:
            return self._create_dummy_task(node)
        if node_type == NodeType.EPHEMERAL:
            return ModelExecutionTask(
                EphemeralOperator(task_id=f"{node['select']}__ephemeral"), None
            )
        # Every other node type is handled as a regular model (run + test).
        return self._create_task_for_model(
            node["select"],
            self.airflow_config.use_task_group,
        )

    def _create_tasks_from_graph(self, dbt_airflow_graph: DbtManifestGraph) -> ModelExecutionTasks:
        """
        Create a task for every graph node and wire tasks along the graph edges.

        :param dbt_airflow_graph: Graph built from the manifest.
        :return: Created tasks together with the graph sources and sinks.
        """
        result_tasks = {
            node_name: self._create_task_from_graph_node(node_name, node)
            for node_name, node in dbt_airflow_graph.get_graph_nodes()
        }
        for node, neighbour in dbt_airflow_graph.get_graph_edges():
            # The end of the upstream node precedes the start of the downstream node.
            # noinspection PyStatementEffect
            (result_tasks[node].get_end_task() >> result_tasks[neighbour].get_start_task())
        return ModelExecutionTasks(
            result_tasks,
            dbt_airflow_graph.get_graph_sources(),
            dbt_airflow_graph.get_graph_sinks(),
        )

    def _make_dbt_tasks(self, manifest_path: str) -> ModelExecutionTasks:
        """
        Load the manifest, build the tasks graph and turn it into Airflow tasks.

        :param manifest_path: Path to ``manifest.json``.
        :return: Tasks created from the manifest.
        """
        manifest = self._load_dbt_manifest(manifest_path)
        dbt_airflow_graph: DbtManifestGraph = create_tasks_graph(
            manifest,
            GraphConfiguration(
                gateway_config=self.gateway_config,
                enable_dags_dependencies=self.airflow_config.enable_dags_dependencies,
                show_ephemeral_models=self.airflow_config.show_ephemeral_models,
                check_all_deps_for_multiple_deps_tests=self.airflow_config.check_all_deps_for_multiple_deps_tests,
            ),
        )
        tasks_with_context = self._create_tasks_from_graph(dbt_airflow_graph)
        logging.info("Created %s tasks groups", tasks_with_context.length())
        return tasks_with_context

    def _create_dag_sensor(self, node: Dict[str, Any]) -> ModelExecutionTask:
        """
        Create a sensor waiting for a model's test task in another DAG.

        :param node: Node attributes; ``select`` and ``dag`` are read here.
        :return: Execution task wrapping the sensor.
        """
        # todo move parameters to configuration
        # The suffix follows the ``use_task_group`` setting of this builder
        # (``.test`` with task groups, ``_test`` without).
        test_task_suffix = ".test" if self.airflow_config.use_task_group else "_test"
        return ModelExecutionTask(
            ExternalTaskSensor(
                task_id="sensor_" + node["select"],
                external_dag_id=node["dag"],
                external_task_id=node["select"] + test_task_suffix,
                timeout=24 * 60 * 60,
                allowed_states=["success"],
                failed_states=["failed", "skipped"],
                mode="reschedule",
            )
        )

    @staticmethod
    def _create_dummy_task(node: Dict[str, Any]) -> ModelExecutionTask:
        """
        Create a placeholder task for a node.

        :param node: Node attributes; ``select`` is used as the task id.
        :return: Execution task wrapping a ``DummyOperator``.
        """
        return ModelExecutionTask(DummyOperator(task_id=node["select"]))