getindata/kedro-airflow-k8s

kedro_airflow_k8s/airflow_dag_template.j2
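{# Jinja2 template rendered by the kedro-airflow-k8s plugin into an Airflow DAG that runs the project's kedro pipeline on Kubernetes. #}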
import logging

from airflow import DAG
from airflow.kubernetes.secret import Secret
from airflow.utils.dates import days_ago
from airflow.utils.task_group import TaskGroup
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.operators.dummy import DummyOperator
from datetime import timedelta
from datetime import datetime
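{# The Spark operator imports below depend on the configured Spark backend (Dataproc or spark-on-Kubernetes). #}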

{% if config.run_config.spark.type == 'dataproc' %}
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocSubmitJobOperator,
    DataprocCreateClusterOperator,
    DataprocDeleteClusterOperator
)
{% elif config.run_config.spark.type in ['kubernetes','k8s'] %}
{{ include_spark_submit_k8s_operator | safe }}
{% endif %}

EXPERIMENT_NAME = "{{ config.run_config.experiment_name | slugify }}"
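{# Optional on-failure callback: posts a Slack notification through the configured webhook connection. #}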

{% if config.run_config.failure_handlers is defined and config.run_config.failure_handlers|length > 0 -%}
def task_fail_handler(context):
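    """Notify the configured failure handlers about the failed task instance."""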
    {%- for handler in config.run_config.failure_handlers %}
    {%- if handler.type == "slack" %}
    from airflow.providers.slack.operators.slack_webhook import SlackWebhookOperator
    from airflow.hooks.base import BaseHook
    SlackWebhookOperator(
        task_id='failure_slack_notification',
        http_conn_id='{{ handler.connection_id }}',
        webhook_token=BaseHook.get_connection('{{ handler.connection_id }}').password,
        message="""{{ handler.message_template }}""".format(
            task=context.get('task_instance').task_id,
            dag=context.get('task_instance').dag_id,
            execution_time=context.get('execution_date'),
            url=context.get('task_instance').log_url),
        username=BaseHook.get_connection('{{ handler.connection_id }}').login
    ).execute(context=context)
    {% endif %}
    {% endfor %}
{% endif %}
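{# Default arguments applied to every task; the failure callback is attached only when failure handlers are configured. #}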

args = {
    'owner': 'airflow',
{%- if config.run_config.failure_handlers is defined and config.run_config.failure_handlers|length > 0 %}
    'on_failure_callback': task_fail_handler,
{% endif %}
}
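{# Operator implementations injected by the plugin at render time: MLflow experiment start, PVC create/init/delete, and the kedro node pod operator. #}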

{% if mlflow_url %}
{{ include_start_mlflow_experiment_operator | safe }}
{% endif %}
{% if not config.run_config.volume.disabled %}
{{ include_create_pipeline_storage_operator | safe }}
{{ include_delete_pipeline_storage_operator | safe }}
{% if not config.run_config.volume.skip_init %}
{{ include_data_volume_init_operator | safe }}
{% endif %}
{% endif %}
{{ include_node_pod_operator | safe }}
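{# DAG definition; month and day of start_date are wrapped in int('..') so zero-padded values such as '03' stay valid Python literals. #}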

with DAG(
    dag_id='{{ dag_name }}',
    description='{{ config.run_config.description }}',
    default_args=args,
    schedule_interval={% if schedule_interval %}'{{ schedule_interval }}'{% else %}None{% endif %},
    start_date={% if config.run_config.start_date %}datetime({{ config.run_config.start_date[0:4] }}, int('{{ config.run_config.start_date[4:6] }}'), int('{{ config.run_config.start_date[6:8] }}')){% else %}days_ago(2){% endif %},
    tags=['commit_sha:{{ git_info.commit_sha }}',
            'generated_with_kedro_airflow_k8s:{{ kedro_airflow_k8s_version }}',
            'experiment_name:'+EXPERIMENT_NAME],
    params={},
) as dag:
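    {# The PVC name is unique per run: the slugified project name plus the run's lowercase ts_nodash. #}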

    pvc_name = '{{ project_name | safe | slugify }}.{% raw %}{{ ts_nodash | lower  }}{% endraw %}'

    {% if mlflow_url %}
    start_mlflow_run = StartMLflowExperimentOperator(
        experiment_name=EXPERIMENT_NAME,
        mlflow_url='{{ mlflow_url }}',
        image="{{ image }}",
        auth_handler={{ config.run_config.auth_config.type }}AuthHandler({%- if config.run_config.auth_config.params -%}[
        {%- for param in config.run_config.auth_config.params -%}
            "{{ param }}",
        {%- endfor -%}]
        {%- endif -%})
    )
    {% endif %}
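    {# Shared pipeline storage: a PVC is created per run, optionally initialised with data from the image, and deleted afterwards. #}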

    {% if not config.run_config.volume.disabled %}
    create_pipeline_storage = CreatePipelineStorageOperator(
        pvc_name=pvc_name,
        namespace='{{ config.run_config.namespace }}',
        access_modes={{config.run_config.volume.access_modes | safe}},
        volume_size='{{ config.run_config.volume.size }}',
        storage_class_name='{{ config.run_config.volume.storageclass }}'
    )

    delete_pipeline_storage = DeletePipelineStorageOperator(
        pvc_name=pvc_name,
        namespace='{{ config.run_config.namespace }}'
    )

    {% if not config.run_config.volume.skip_init %}
    data_volume_init = DataVolumeInitOperator(
        namespace='{{ config.run_config.namespace }}',
        pvc_name=pvc_name,
        image='{{ image }}',
        image_pull_policy='{{ config.run_config.image_pull_policy }}',
        startup_timeout={{ config.run_config.startup_timeout }},
        volume_owner={{ config.run_config.volume.owner }},
        {%- if config.run_config.image_pull_secrets %}
        image_pull_secrets="{{ config.run_config.image_pull_secrets }}",
        {%- endif %}
        {%- if config.run_config.service_account_name %}
        service_account_name="{{ config.run_config.service_account_name }}",
        {%- endif %}
    )
    {% endif %}
    {% endif %}

    {%- if custom_spark_factory %}
    {{ custom_spark_factory.imports_statement | indent() }}
    {%- endif %}

    final = DummyOperator(task_id='final', dag=dag)

    tasks = {}
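    {# All node tasks live in the "kedro" task group; pyspark groups additionally get create/delete Spark cluster tasks around them. #}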
    with TaskGroup("kedro", prefix_group_id=False) as kedro_group:
    {% if (pipeline_grouped|selectattr('group_type','eq','pyspark')|first) %}
        {% set init_script = config.run_config.spark.artifacts_path ~ '/' ~ project_name ~ '-' ~ git_info.commit_sha ~ '.sh' %}
        init_script = '{{ init_script }}'
    {%- if config.run_config.spark.type == 'dataproc' %}
        cluster_config = {{ config.run_config.spark.cluster_config | safe }}
        if 'initialization_actions' not in cluster_config:
            cluster_config['initialization_actions'] = []
        cluster_config['initialization_actions'].append({'executable_file':init_script})
        tasks["create-spark-cluster"] = DataprocCreateClusterOperator(
            cluster_name='{{ config.run_config.spark.cluster_name }}',
            region='{{ config.run_config.spark.region }}',
            project_id='{{ config.run_config.spark.project_id }}',
            cluster_config=cluster_config,
            task_id='create-spark-cluster'
        )
        tasks["delete-spark-cluster"] = DataprocDeleteClusterOperator(
            project_id='{{ config.run_config.spark.project_id }}',
            region='{{ config.run_config.spark.region }}',
            cluster_name='{{ config.run_config.spark.cluster_name }}',
            task_id='delete-spark-cluster',
            trigger_rule='all_done'
        )
    {%- elif config.run_config.spark.type == 'k8s' or config.run_config.spark.type == 'kubernetes' %}
        tasks["create-spark-cluster"] = DummyOperator(task_id='create-spark-cluster')
        tasks["delete-spark-cluster"] = DummyOperator(task_id='delete-spark-cluster')
    {%- elif custom_spark_factory %}
        tasks["create-spark-cluster"] = {{ custom_spark_factory.create_cluster_operator(project_name=project_name, config=config, init_script_path=init_script, cluster_config=config.run_config.spark.cluster_config) | safe | indent(8) }}
        tasks["delete-spark-cluster"] = {{ custom_spark_factory.delete_cluster_operator(project_name=project_name, config=config) | safe | indent(8) }}
    {% endif %}
        tasks["create-spark-cluster"] >> tasks["delete-spark-cluster"]
    {% endif %}
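    {# One task per node group: pyspark groups are submitted to the configured Spark backend, all other nodes run as Kubernetes pods. #}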
    {% for node in pipeline_grouped %}
        {%- if node.group_type == 'pyspark' %}
        {%- if config.run_config.spark.type == 'dataproc' %}
        PYSPARK_JOB = {
            "reference": {"project_id": '{{ config.run_config.spark.project_id }}'},
            "placement": {"cluster_name": '{{ config.run_config.spark.cluster_name }}'},
            "pyspark_job": {
                "main_python_file_uri": '{{ config.run_config.spark.artifacts_path }}/{{ project_name }}-{{ git_info.commit_sha }}-{{ node.name }}.py',
                "properties": {
                    {% if mlflow_url %}
                    "spark.executorEnv.MLFLOW_RUN_ID": "{% raw %}{{ ti.xcom_pull(key='mlflow_run_id') }}{% endraw %}",
                    "spark.yarn.appMasterEnv.MLFLOW_RUN_ID": "{% raw %}{{ ti.xcom_pull(key='mlflow_run_id') }}{% endraw %}",
                    {% endif %}
                },
            },
        }
        tasks["{{ node.name | slugify }}"] = DataprocSubmitJobOperator(
            task_id="kedro-{{ node.name | slugify }}",
            job=PYSPARK_JOB,
            location='{{ config.run_config.spark.region }}',
            project_id='{{ config.run_config.spark.project_id }}',
        )
        {%- elif config.run_config.spark.type == 'k8s' or config.run_config.spark.type == 'kubernetes' %}
        tasks["{{ node.name | slugify }}"] = SparkSubmitK8SOperator(
                node_name='{{ node.name | slugify }}',
                kedro_script='{{ config.run_config.spark.cluster_config.run_script }}',
                image={% if config.run_config.spark.cluster_config.image %}"{{ config.run_config.spark.cluster_config.image }}"{%- elif k8s_templates[node.name].image -%}"{{ k8s_templates[node.name].image }}"{%- else -%}"{{ image }}"{%- endif %},
                run_name='{{ dag_name | slugify }}',
                nodes=[{% for task in node.task_group %}'{{ task.name }}',{% endfor %}],
                env='{{ env }}',
                conf={
                    {%- for key, value in config.run_config.spark.cluster_config.conf.items() %}
                    "{{ key }}": "{{ value }}",
                    {%- endfor %}
                },
                requests_cpu="{{ config.run_config.spark.cluster_config.requests.cpu if config.run_config.spark.cluster_config.requests.cpu }}",
                limits_cpu="{{ config.run_config.spark.cluster_config.limits.cpu if config.run_config.spark.cluster_config.limits.cpu }}",
                limits_memory="{{ config.run_config.spark.cluster_config.limits.memory if config.run_config.spark.cluster_config.limits.memory }}",
                image_pull_policy='{{ config.run_config.image_pull_policy }}',
                driver_port="{{ config.run_config.spark.cluster_config.driver_port if config.run_config.spark.cluster_config.driver_port }}",
                block_manager_port="{{ config.run_config.spark.cluster_config.block_manager_port if config.run_config.spark.cluster_config.block_manager_port }}",
                secrets={
                    {%- for key, value in config.run_config.spark.cluster_config.secrets.items() %}
                    "{{ key }}": "{{ value }}",
                    {%- endfor %}
                },
                labels={
                    {%- for key, value in config.run_config.spark.cluster_config.labels.items() %}
                    "{{ key }}": "{{ value }}",
                    {%- endfor %}
                },
                {%- if config.run_config.service_account_name %}
                service_account_name="{{ config.run_config.service_account_name }}",
                {%- endif %}
                local_storage_class_name="{{ config.run_config.spark.cluster_config.local_storage.class_name if config.run_config.spark.cluster_config.local_storage.class_name }}",
                local_storage_size="{{ config.run_config.spark.cluster_config.local_storage.size if config.run_config.spark.cluster_config.local_storage.size }}",
                namespace='{{ config.run_config.namespace }}',
                conn_id="{{ config.run_config.spark.cluster_name if config.run_config.spark.cluster_name }}",
                env_vars={
                    {%- if mlflow_url %}
                    "MLFLOW_RUN_ID": "{% raw %}{{ ti.xcom_pull(key='mlflow_run_id') }}{% endraw %}",
                    {%- endif %}
                    {%- for key, value in config.run_config.spark.cluster_config.env_vars.items() %}
                    "{{ key }}": "{{ value }}",
                    {%- endfor %}
                },
                jars="{{ config.run_config.spark.cluster_config.jars if config.run_config.spark.cluster_config.jars }}",
                packages="{{ config.run_config.spark.cluster_config.packages if config.run_config.spark.cluster_config.packages }}",
                repositories="{{ config.run_config.spark.cluster_config.repositories if config.run_config.spark.cluster_config.repositories }}",
                num_executors="{{ config.run_config.spark.cluster_config.num_executors if config.run_config.spark.cluster_config.num_executors else "1" }}",
        )
        {%- elif custom_spark_factory %}
        {% set main_python_path = config.run_config.spark.artifacts_path ~ '/' ~ project_name ~ '-' ~ git_info.commit_sha ~ '-' ~ node.name ~ '.py' %}
        tasks["{{ node.name | slugify }}"] = {{ custom_spark_factory.submit_operator(project_name=project_name, pipeline=pipeline_name, node=node, config=config, main_python_file_path=main_python_path ) | safe  | indent(8)}}
        {%- endif %}
        tasks["{{ node.name | slugify }}"].doc_md = """
        Task automatically generated by kedro-airflow-k8s.
        This Spark task combines the following kedro tasks:
        {% for task in node.task_group %}
        - {{ task.name }}
        {%- endfor %}
        """
        tasks["create-spark-cluster"] >> tasks["{{ node.name | slugify }}"] >> tasks["delete-spark-cluster"]
        {%- else %}
        tasks["{{ node.name | slugify }}"] = NodePodOperator(
            node_name="{{node.name}}",
            namespace="{{config.run_config.namespace}}",
            volume_disabled={{config.run_config.volume.disabled}},
            pvc_name=pvc_name,
            image={%- if k8s_templates[node.name].image -%}
            "{{ k8s_templates[node.name].image }}",
            {%- else -%}
            "{{ image }}",
            {%- endif %}
            image_pull_policy="{{ config.run_config.image_pull_policy }}",
            {%- if config.run_config.image_pull_secrets %}
            image_pull_secrets="{{ config.run_config.image_pull_secrets }}",
            {%- endif %}
            {%- if config.run_config.service_account_name %}
            service_account_name="{{ config.run_config.service_account_name }}",
            {%- endif %}
            env="{{env}}",
            pipeline="{{ pipeline_name }}",
            task_id="{{ node.name | slugify }}",
            startup_timeout={{ config.run_config.startup_timeout }},
            volume_owner={{ config.run_config.volume.owner }},
            mlflow_enabled={% if mlflow_url %}True{% else %}False{% endif %},
            requests_cpu="{{resources[node.name].requests.cpu if resources[node.name].requests.cpu}}",
            requests_memory="{{resources[node.name].requests.memory if resources[node.name].requests.memory}}",
            limits_cpu="{{resources[node.name].limits.cpu if resources[node.name].limits.cpu}}",
            limits_memory="{{resources[node.name].limits.memory if resources[node.name].limits.memory}}",
            node_selector_labels={
            {%- for key, value in resources[node.name].node_selectors.items() %}
                "{{ key }}": "{{ value }}",
                {%- endfor %}
            },
            labels={
            {%- for key, value in resources[node.name].labels.items() %}
                "{{ key }}": "{{ value }}",
                {%- endfor %}
            },
            tolerations=[
            {%- for tolerations_items in resources[node.name].tolerations %}
                {
                {%- for key, value in tolerations_items.items() %}
                   "{{ key }}": "{{ value }}",
                {%- endfor %}
                },
            {%- endfor %}
            ],
            annotations={
            {%- for key, value in resources[node.name].annotations.items() %}
                "{{ key }}": """{{ value | safe }}""",
            {%- endfor %}
            },
            secrets=[
            {%- for secret in secrets %}
                Secret("{{secret.deploy_type}}", {{  "\"{}\"".format(secret.deploy_target)|safe if secret.deploy_target else None}}, "{{secret.secret}}", {{ "\"{}\"".format(secret.key)|safe if secret.key else None}}),
            {%- endfor %}
            ],
            parameters="""
            {%- for param in macro_params %}
                {{ param }}:{{'{{'}} {{  param }} {{ '}}' }},
                {%- endfor %}
            {%- for param in variables_params %}
                {{ param }}:{{'{{ var.value.'}}{{  param }} {{ '}}' }},
                {%- endfor %}
            """,
            env_vars={
            {%- for var in config.run_config.env_vars %}
                "{{ var }}": "{{'{{ var.value.'}}{{ var }} {{ '}}' }}",
            {%- endfor %}
            },
            {%- if k8s_templates[node.name].template %}
            kubernetes_pod_template=f"""{{ k8s_templates[node.name].template | safe }}""",
            {%- endif %}
        )
        {%- endif %}
    {% endfor %}
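    {# Recreate the kedro pipeline's dependency graph between the generated tasks. #}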


    {% for task_group in pipeline_grouped -%}
    {% for child in task_group.children %}
        tasks["{{ task_group.name | slugify }}"] >> tasks["{{ child.name | slugify }}"]
    {% endfor %}
    {%- endfor %}
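    {# Optional sensors that block the kedro group until the configured upstream DAGs have completed. #}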

    {% if with_external_dependencies %}
    with TaskGroup("external_pipelines", prefix_group_id=False):
    {% for dependency in config.run_config.external_dependencies %}
        ExternalTaskSensor(external_dag_id='{{ dependency.dag_id }}',
                           external_task_id='{{ dependency.task_id if dependency.task_id }}',
                           check_existence={{ dependency.check_existence }},
                           execution_delta=timedelta(minutes={{ dependency.execution_delta }}),
                           task_id='wait_for_{{ dependency.dag_id }}_{{ dependency.task_id }}',
                           timeout={{ dependency.timeout }} * 60,
                           ) >> kedro_group
    {% endfor %}
    {% endif %}
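    {# Wire the storage lifecycle and the MLflow start task around the kedro group, ending at the final marker task. #}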

    {% if not config.run_config.volume.disabled %}
        {% if not config.run_config.volume.skip_init %}
    create_pipeline_storage >> data_volume_init
    data_volume_init >> delete_pipeline_storage
        {% else %}
    create_pipeline_storage >> delete_pipeline_storage
        {% endif %}
    {% endif %}

    {% if mlflow_url %}
    start_mlflow_run >> kedro_group
    {% endif %}
    {% if not config.run_config.volume.disabled and not config.run_config.volume.skip_init %}
    data_volume_init >> kedro_group
    {% endif %}

    {% if not config.run_config.volume.disabled %}
    kedro_group >> delete_pipeline_storage >> final
    {% endif %}
    kedro_group >> final

if __name__ == "__main__":
    dag.cli()