getindata/kedro-airflow-k8s

kedro_airflow_k8s/template.py

Summary

Maintainability: A (1 hr)
Test Coverage: A (94%)
from collections import defaultdict
from pathlib import Path
from typing import Dict, Optional

import jinja2
import kedro
from jinja2.environment import TemplateStream
from slugify import slugify

from kedro_airflow_k8s import version
from kedro_airflow_k8s.config import KubernetesPodTemplate, ResourceConfig


def get_commit_sha(context_helper):
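    """Return the git commit SHA recorded in the Kedro session store,
    or "UNKNOWN" if no git metadata is available."""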
    try:
        return context_helper.session.store["git"]["commit_sha"]
    except KeyError:
        return "UNKNOWN"


def get_mlflow_url(context_helper):
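    """Return the MLflow tracking URI from the kedro-mlflow configuration,
    or None if kedro-mlflow is not installed or not configured."""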
    try:
        import importlib

        importlib.import_module("kedro_mlflow")
        return context_helper.mlflow_config["mlflow_tracking_uri"]
    except (ModuleNotFoundError, kedro.config.config.MissingConfigException):
        return None


def _get_jinja_template(name: str):
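    """Load a Jinja2 template bundled with this package, with a `slugify`
    filter registered."""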
    loader = jinja2.FileSystemLoader(str(Path(__file__).parent))
    jinja_env = jinja2.Environment(
        autoescape=True, loader=loader, lstrip_blocks=True
    )
    jinja_env.filters["slugify"] = slugify
    template = jinja_env.get_template(name)
    return template


def _node_resources(nodes, config) -> Dict[str, ResourceConfig]:
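    """Map each node name to its ResourceConfig based on the node's
    `resources:<name>` tag; nodes without such a tag get `__default__`."""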
    result = {}
    default_config = config.__default__
    for node in nodes:
        resources = [
            tag[len("resources:") :]  # noqa: E203
            for tag in node.tags
            if tag.startswith("resources:")
        ]
        result[node.name] = (
            config[resources[0]] if resources else default_config
        )
    return result
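
# Illustrative sketch (hypothetical node and profile names): a Kedro node
# tagged "resources:gpu_profile", e.g.
#
#     node(train_model, "features", "model", name="train_model",
#          tags=["resources:gpu_profile"]),
#
# resolves to run_config.resources["gpu_profile"] above, while untagged nodes
# fall back to run_config.resources.__default__.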


def _pod_templates(nodes, config) -> Dict[str, KubernetesPodTemplate]:
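    """Map each node name to its KubernetesPodTemplate based on the node's
    `k8s_template:<name>` tag; nodes without such a tag get `__default__`."""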
    result = defaultdict(lambda: None)

    default_config = config.__default__
    for node in nodes:
        templates = [
            tag[len("k8s_template:") :]  # noqa: E203
            for tag in node.tags
            if tag.startswith("k8s_template:")
        ]
        result[node.name] = (
            config[templates[0]] if templates else default_config
        )
    return result


def _create_template_stream(
    context_helper,
    dag_name: str,
    schedule_interval: str,
    image: str,
    with_external_dependencies: bool,
) -> TemplateStream:
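    """Render the Airflow DAG Jinja2 template as a stream, passing the
    pipeline, node resources, pod templates, MLflow settings and the bundled
    operator sources as template context."""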
    template = _get_jinja_template("airflow_dag_template.j2")

    custom_spark_factory = None
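    # run_config.spark.operator_factory holds a dotted path
    # ("package.module.ClassName"); import the module and instantiate the
    # factory class so the template can emit custom Spark operators.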
    if context_helper.config.run_config.spark.operator_factory:
        factory = context_helper.config.run_config.spark.operator_factory
        pkg = factory[: factory.rindex(".")]
        clazz = factory[factory.rindex(".") + 1 :]  # noqa: E203
        mod = __import__(pkg, fromlist=[clazz])
        custom_spark_factory = getattr(mod, clazz)()

    return template.stream(
        custom_spark_factory=custom_spark_factory,
        pipeline=context_helper.pipeline,
        pipeline_grouped=context_helper.pipeline_grouped,
        with_external_dependencies=with_external_dependencies,
        config=context_helper.config,
        resources=_node_resources(
            context_helper.pipeline.nodes,
            context_helper.config.run_config.resources,
        ),
        mlflow_url=get_mlflow_url(context_helper),
        env=context_helper.env,
        project_name=context_helper.project_name,
        dag_name=dag_name,
        image=image,
        pipeline_name=context_helper.pipeline_name,
        schedule_interval=schedule_interval,
        git_info=context_helper.session.store["git"],
        kedro_airflow_k8s_version=version,
        include_start_mlflow_experiment_operator=(
            Path(__file__).parent / "operators/start_mlflow_experiment.py"
        ).read_text(),
        include_create_pipeline_storage_operator=(
            Path(__file__).parent / "operators/create_pipeline_storage.py"
        ).read_text(),
        include_delete_pipeline_storage_operator=(
            Path(__file__).parent / "operators/delete_pipeline_storage.py"
        ).read_text(),
        include_data_volume_init_operator=(
            Path(__file__).parent / "operators/data_volume_init.py"
        ).read_text(),
        include_node_pod_operator=(
            Path(__file__).parent / "operators/node_pod.py"
        ).read_text(),
        include_spark_submit_k8s_operator=(
            Path(__file__).parent / "operators/spark_submit_k8s.py"
        ).read_text(),
        secrets=context_helper.config.run_config.secrets,
        macro_params=context_helper.config.run_config.macro_params,
        variables_params=context_helper.config.run_config.variables_params,
        k8s_templates=_pod_templates(
            context_helper.pipeline.nodes,
            context_helper.config.run_config.kubernetes_pod_templates,
        ),
    )


def _create_spark_tasks_template_stream(
    context_helper,
) -> Dict[str, TemplateStream]:
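    """Render one Spark task template stream per "pyspark" task group,
    keyed by the group name."""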
    spark_task_templates = {}
    spark_task_groups = [
        tg
        for tg in context_helper.pipeline_grouped
        if tg.group_type == "pyspark"
    ]
    for tg in spark_task_groups:
        node_names = [node.name for node in tg.task_group]
        template = _get_jinja_template("airflow_spark_task_template.j2")
        spark_task_templates[tg.name] = template.stream(
            node_names=node_names,
            project_name=context_helper.project_name,
            kedro_env=context_helper.env,
        )
    return spark_task_templates


def get_cron_expression(
    ctx, cron_expression: Optional[str] = None
) -> Optional[str]:
    config = ctx.obj["context_helper"].config
    return cron_expression or config.run_config.cron_expression


def get_dag_filename_and_template_stream(
    ctx,
    cron_expression: Optional[str] = None,
    dag_name: Optional[str] = None,
    image: Optional[str] = None,
    with_external_dependencies: bool = True,
):
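    """Resolve the DAG name and return it together with the rendered DAG
    template stream and the per-group Spark task template streams."""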
    config = ctx.obj["context_helper"].config
    dag_name = dag_name or config.run_config.run_name

    template_stream = _create_template_stream(
        ctx.obj["context_helper"],
        dag_name=dag_name,
        schedule_interval=cron_expression,
        image=image or config.run_config.image,
        with_external_dependencies=with_external_dependencies,
    )
    spark_template_streams = _create_spark_tasks_template_stream(
        ctx.obj["context_helper"],
    )
    return dag_name, template_stream, spark_template_streams
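
# Illustrative follow-up (hypothetical target paths, not part of this module):
# the returned Jinja2 TemplateStream objects can be written out with
# TemplateStream.dump, e.g.
#
#     dag_name, dag_stream, spark_streams = get_dag_filename_and_template_stream(ctx)
#     dag_stream.dump(f"dags/{dag_name}.py")
#     for group_name, stream in spark_streams.items():
#         stream.dump(f"dags/{dag_name}_{group_name}.py")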