kedro_airflow_k8s/airflow_dag_template.j2
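{#
Jinja2 template rendered by kedro-airflow-k8s into a standalone Airflow DAG file.
"config" is the plugin's run configuration, the include_* variables carry operator
source inlined at render time, and "slugify" is a filter supplied by the plugin's
Jinja environment.
#}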
import logging
from airflow import DAG
from airflow.kubernetes.secret import Secret
from airflow.utils.dates import days_ago
from airflow.utils.task_group import TaskGroup
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.operators.dummy import DummyOperator
from datetime import timedelta
from datetime import datetime
{% if config.run_config.spark.type == 'dataproc' %}
from airflow.providers.google.cloud.operators.dataproc import (
DataprocSubmitJobOperator,
DataprocCreateClusterOperator,
DataprocDeleteClusterOperator
)
{% elif config.run_config.spark.type in ['kubernetes','k8s'] %}
{{ include_spark_submit_k8s_operator | safe }}
{% endif %}
EXPERIMENT_NAME = "{{ config.run_config.experiment_name | slugify }}"
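{# Optional on-failure notifications; only the "slack" handler type currently renders code. #}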
{% if config.run_config.failure_handlers is defined and config.run_config.failure_handlers|length > 0 -%}
def task_fail_handler(context):
"""Dispatch every configured failure handler for the failed task instance."""
{%- for handler in config.run_config.failure_handlers %}
{%- if handler.type == "slack" %}
from airflow.providers.slack.operators.slack_webhook import SlackWebhookOperator
from airflow.hooks.base import BaseHook
SlackWebhookOperator(
task_id='failure_slack_notification',
http_conn_id='{{ handler.connection_id }}',
webhook_token=BaseHook.get_connection('{{ handler.connection_id }}').password,
message="""{{ handler.message_template }}""".format(
task=context.get('task_instance').task_id,
dag=context.get('task_instance').dag_id,
execution_time=context.get('execution_date'),
url=context.get('task_instance').log_url),
username=BaseHook.get_connection('{{ handler.connection_id }}').login
).execute(context=context)
{% endif %}
{% endfor %}
{% endif %}
args = {
'owner': 'airflow',
{%- if config.run_config.failure_handlers is defined and config.run_config.failure_handlers|length > 0 %}
'on_failure_callback': task_fail_handler,
{% endif %}
}
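{# The include_* variables inline the plugin's generated operator class definitions into this module. #}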
{% if mlflow_url %}
{{ include_start_mlflow_experiment_operator | safe }}
{% endif %}
{% if not config.run_config.volume.disabled %}
{{ include_create_pipeline_storage_operator | safe }}
{{ include_delete_pipeline_storage_operator | safe }}
{% if not config.run_config.volume.skip_init %}
{{ include_data_volume_init_operator | safe }}
{% endif %}
{% endif %}
{{ include_node_pod_operator | safe }}
with DAG(
dag_id='{{ dag_name }}',
description='{{ config.run_config.description }}',
default_args=args,
schedule_interval={% if schedule_interval %}'{{ schedule_interval }}'{% else %}None{% endif %},
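# start_date arrives as a YYYYMMDD string and is parsed into a datetime here.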
start_date={% if config.run_config.start_date %}datetime(int('{{ config.run_config.start_date[0:4] }}'), int('{{ config.run_config.start_date[4:6] }}'), int('{{ config.run_config.start_date[6:8] }}')){% else %}days_ago(2){% endif %},
tags=['commit_sha:{{ git_info.commit_sha }}',
'generated_with_kedro_airflow_k8s:{{ kedro_airflow_k8s_version }}',
'experiment_name:'+EXPERIMENT_NAME],
params={},
) as dag:
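# pvc_name keeps the ts_nodash Airflow macro un-rendered so each run
# resolves to a unique volume claim name at execution time.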
pvc_name = '{{ project_name | safe | slugify }}.{% raw %}{{ ts_nodash | lower }}{% endraw %}'
{% if mlflow_url %}
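# Start the MLflow run up front; downstream tasks read its run id from XCom.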
start_mlflow_run = StartMLflowExperimentOperator(
experiment_name=EXPERIMENT_NAME,
mlflow_url='{{ mlflow_url }}',
image="{{ image }}",
auth_handler={{ config.run_config.auth_config.type }}AuthHandler({%- if config.run_config.auth_config.params -%}[
{%- for param in config.run_config.auth_config.params -%}
"{{ param }}",
{%- endfor -%}]
{%- endif -%})
)
{% endif %}
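{# Per-run PVC lifecycle: create storage, optionally seed it with data, and delete it once the pipeline finishes. #}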
{% if not config.run_config.volume.disabled %}
create_pipeline_storage = CreatePipelineStorageOperator(
pvc_name=pvc_name,
namespace='{{ config.run_config.namespace }}',
access_modes={{ config.run_config.volume.access_modes | safe }},
volume_size='{{ config.run_config.volume.size }}',
storage_class_name='{{ config.run_config.volume.storageclass }}'
)
delete_pipeline_storage = DeletePipelineStorageOperator(
pvc_name=pvc_name,
namespace='{{ config.run_config.namespace }}'
)
{% if not config.run_config.volume.skip_init %}
data_volume_init = DataVolumeInitOperator(
namespace='{{ config.run_config.namespace }}',
pvc_name=pvc_name,
image='{{ image }}',
image_pull_policy='{{ config.run_config.image_pull_policy }}',
startup_timeout={{ config.run_config.startup_timeout }},
volume_owner={{ config.run_config.volume.owner }},
{%- if config.run_config.image_pull_secrets %}
image_pull_secrets="{{ config.run_config.image_pull_secrets }}",
{%- endif %}
{%- if config.run_config.service_account_name %}
service_account_name="{{ config.run_config.service_account_name }}",
{%- endif %}
)
{% endif %}
{% endif %}
{%- if custom_spark_factory %}
{{ custom_spark_factory.imports_statement | indent() }}
{%- endif %}
final = DummyOperator(task_id='final', dag=dag)
tasks = {}
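# "tasks" maps slugified node-group names to operators so the pipeline topology can be wired up below.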
with TaskGroup("kedro", prefix_group_id=False) as kedro_group:
{% if (pipeline_grouped|selectattr('group_type','eq','pyspark')|first) %}
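{# Any pyspark group triggers a shared create/delete Spark cluster pair; init_script points at the project's bootstrap script in artifacts_path. #}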
{% set init_script = config.run_config.spark.artifacts_path ~ '/' ~ project_name ~ '-' ~ git_info.commit_sha ~ '.sh' %}
init_script = '{{ init_script }}'
{%- if config.run_config.spark.type == 'dataproc' %}
cluster_config = {{ config.run_config.spark.cluster_config | safe }}
if 'initialization_actions' not in cluster_config:
cluster_config['initialization_actions'] = []
cluster_config['initialization_actions'].append({'executable_file':init_script})
tasks["create-spark-cluster"] = DataprocCreateClusterOperator(
cluster_name='{{ config.run_config.spark.cluster_name }}',
region='{{ config.run_config.spark.region }}',
project_id='{{ config.run_config.spark.project_id }}',
cluster_config=cluster_config,
task_id='create-spark-cluster'
)
tasks["delete-spark-cluster"] = DataprocDeleteClusterOperator(
project_id='{{ config.run_config.spark.project_id }}',
region='{{ config.run_config.spark.region }}',
cluster_name='{{ config.run_config.spark.cluster_name }}',
task_id='delete-spark-cluster',
trigger_rule='all_done'
)
{%- elif config.run_config.spark.type == 'k8s' or config.run_config.spark.type == 'kubernetes' %}
tasks["create-spark-cluster"] = DummyOperator(task_id='create-spark-cluster')
tasks["delete-spark-cluster"] = DummyOperator(task_id='delete-spark-cluster')
{%- elif custom_spark_factory %}
tasks["create-spark-cluster"] = {{ custom_spark_factory.create_cluster_operator(project_name=project_name, config=config, init_script_path=init_script, cluster_config=config.run_config.spark.cluster_config) | safe | indent(8) }}
tasks["delete-spark-cluster"] = {{ custom_spark_factory.delete_cluster_operator(project_name=project_name, config=config) | safe | indent(8) }}
{% endif %}
tasks["create-spark-cluster"] >> tasks["delete-spark-cluster"]
{% endif %}
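{# One Airflow task per node group: pyspark groups submit a generated Spark job, everything else runs as a single Kubernetes pod. #}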
{% for node in pipeline_grouped %}
{%- if node.group_type == 'pyspark' %}
{%- if config.run_config.spark.type == 'dataproc' %}
PYSPARK_JOB = {
"reference": {"project_id": '{{ config.run_config.spark.project_id }}'},
"placement": {"cluster_name": '{{ config.run_config.spark.cluster_name }}'},
"pyspark_job": {"main_python_file_uri": '{{ config.run_config.spark.artifacts_path }}/{{ project_name }}-{{ git_info.commit_sha }}-{{ node.name }}.py',
"properties": {
{% if mlflow_url %}
"spark.executorEnv.MLFLOW_RUN_ID": "{% raw %}{{ ti.xcom_pull(key='mlflow_run_id') }}{% endraw %}",
"spark.yarn.appMasterEnv.MLFLOW_RUN_ID": "{% raw %}{{ ti.xcom_pull(key='mlflow_run_id') }}{% endraw %}",
{% endif %}
}
},
}
tasks["{{ node.name | slugify }}"] = DataprocSubmitJobOperator(
task_id="kedro-{{ node.name | slugify }}", job=PYSPARK_JOB, location='{{ config.run_config.spark.region }}', project_id='{{ config.run_config.spark.project_id }}'
)
{%- elif config.run_config.spark.type == 'k8s' or config.run_config.spark.type == 'kubernetes' %}
tasks["{{ node.name | slugify }}"] = SparkSubmitK8SOperator(
node_name='{{ node.name | slugify }}',
kedro_script='{{ config.run_config.spark.cluster_config.run_script }}',
image={% if config.run_config.spark.cluster_config.image %}"{{ config.run_config.spark.cluster_config.image }}"{%- elif k8s_templates[node.name].image -%}"{{ k8s_templates[node.name].image }}"{%- else -%}"{{ image }}"{%- endif %},
run_name='{{ dag_name | slugify }}',
nodes=[{% for task in node.task_group %}'{{ task.name }}',{% endfor %}],
env='{{ env }}',
conf={
{%- for key, value in config.run_config.spark.cluster_config.conf.items() %}
"{{ key }}": "{{ value }}",
{%- endfor %}
},
requests_cpu="{{ config.run_config.spark.cluster_config.requests.cpu if config.run_config.spark.cluster_config.requests.cpu }}",
limits_cpu="{{ config.run_config.spark.cluster_config.limits.cpu if config.run_config.spark.cluster_config.limits.cpu }}",
limits_memory="{{ config.run_config.spark.cluster_config.limits.memory if config.run_config.spark.cluster_config.limits.memory }}",
image_pull_policy='{{ config.run_config.image_pull_policy }}',
driver_port="{{ config.run_config.spark.cluster_config.driver_port if config.run_config.spark.cluster_config.driver_port }}",
block_manager_port="{{ config.run_config.spark.cluster_config.block_manager_port if config.run_config.spark.cluster_config.block_manager_port }}",
secrets={
{%- for key, value in config.run_config.spark.cluster_config.secrets.items() %}
"{{ key }}": "{{ value }}",
{%- endfor %}
},
labels={
{%- for key, value in config.run_config.spark.cluster_config.labels.items() %}
"{{ key }}": "{{ value }}",
{%- endfor %}
},
{%- if config.run_config.service_account_name %}
service_account_name="{{ config.run_config.service_account_name }}",
{%- endif %}
local_storage_class_name="{{ config.run_config.spark.cluster_config.local_storage.class_name if config.run_config.spark.cluster_config.local_storage.class_name }}",
local_storage_size="{{ config.run_config.spark.cluster_config.local_storage.size if config.run_config.spark.cluster_config.local_storage.size }}",
namespace='{{ config.run_config.namespace }}',
conn_id="{{ config.run_config.spark.cluster_name if config.run_config.spark.cluster_name }}",
env_vars={
{%- if mlflow_url %}
"MLFLOW_RUN_ID": "{% raw %}{{ ti.xcom_pull(key='mlflow_run_id') }}{% endraw %}",
{%- endif %}
{%- for key, value in config.run_config.spark.cluster_config.env_vars.items() %}
"{{ key }}": "{{ value }}",
{%- endfor %}
},
jars="{{ config.run_config.spark.cluster_config.jars if config.run_config.spark.cluster_config.jars }}",
packages="{{ config.run_config.spark.cluster_config.packages if config.run_config.spark.cluster_config.packages }}",
repositories="{{ config.run_config.spark.cluster_config.repositories if config.run_config.spark.cluster_config.repositories }}",
num_executors="{{ config.run_config.spark.cluster_config.num_executors if config.run_config.spark.cluster_config.num_executors else "1" }}",
)
{%- elif custom_spark_factory %}
{% set main_python_path = config.run_config.spark.artifacts_path ~ '/' ~ project_name ~ '-' ~ git_info.commit_sha ~ '-' ~ node.name ~ '.py' %}
tasks["{{ node.name | slugify }}"] = {{ custom_spark_factory.submit_operator(project_name=project_name, pipeline=pipeline_name, node=node, config=config, main_python_file_path=main_python_path ) | safe | indent(8)}}
{%- endif %}
tasks["{{ node.name | slugify }}"].doc_md = """
Task automatically generated by kedro-airflow-k8s.
This Spark task combines the following kedro tasks:
{% for task in node.task_group %}
- {{ task.name }}
{%- endfor %}
"""
tasks["create-spark-cluster"] >> tasks["{{ node.name | slugify }}"] >> tasks["delete-spark-cluster"]
{%- else %}
tasks["{{ node.name | slugify }}"] = NodePodOperator(
node_name="{{node.name}}",
namespace="{{config.run_config.namespace}}",
volume_disabled={{config.run_config.volume.disabled}},
pvc_name=pvc_name,
image={%- if k8s_templates[node.name].image -%}
"{{ k8s_templates[node.name].image }}",
{%- else -%}
"{{ image }}",
{%- endif %}
image_pull_policy="{{ config.run_config.image_pull_policy }}",
{%- if config.run_config.image_pull_secrets %}
image_pull_secrets="{{ config.run_config.image_pull_secrets }}",
{%- endif %}
{%- if config.run_config.service_account_name %}
service_account_name="{{ config.run_config.service_account_name }}",
{%- endif %}
env="{{env}}",
pipeline="{{ pipeline_name }}",
task_id="{{ node.name | slugify }}",
startup_timeout={{ config.run_config.startup_timeout }},
volume_owner={{ config.run_config.volume.owner }},
mlflow_enabled={% if mlflow_url %}True{% else %}False{% endif %},
requests_cpu="{{resources[node.name].requests.cpu if resources[node.name].requests.cpu}}",
requests_memory="{{resources[node.name].requests.memory if resources[node.name].requests.memory}}",
limits_cpu="{{resources[node.name].limits.cpu if resources[node.name].limits.cpu}}",
limits_memory="{{resources[node.name].limits.memory if resources[node.name].limits.memory}}",
node_selector_labels={
{%- for key, value in resources[node.name].node_selectors.items() %}
"{{ key }}": "{{ value }}",
{%- endfor %}
},
labels={
{%- for key, value in resources[node.name].labels.items() %}
"{{ key }}": "{{ value }}",
{%- endfor %}
},
tolerations=[
{%- for tolerations_items in resources[node.name].tolerations %}
{
{%- for key, value in tolerations_items.items() %}
"{{ key }}": "{{ value }}",
{%- endfor %}
},
{%- endfor %}
],
annotations={
{%- for key, value in resources[node.name].annotations.items() %}
"{{ key }}": """{{ value | safe }}""",
{%- endfor %}
},
secrets=[
{%- for secret in secrets %}
Secret("{{secret.deploy_type}}", {{ "\"{}\"".format(secret.deploy_target)|safe if secret.deploy_target else None}}, "{{secret.secret}}", {{ "\"{}\"".format(secret.key)|safe if secret.key else None}}),
{%- endfor %}
],
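# The doubled braces below are escaped so the rendered DAG keeps literal
# Airflow macros: dag-run params and Variables resolve at task run time.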
parameters="""
{%- for param in macro_params %}
{{ param }}:{{'{{'}} {{ param }} {{ '}}' }},
{%- endfor %}
{%- for param in variables_params %}
{{ param }}:{{'{{ var.value.'}}{{ param }} {{ '}}' }},
{%- endfor %}
""",
env_vars={
{%- for var in config.run_config.env_vars %}
"{{ var }}": "{{'{{ var.value.'}}{{ var }} {{ '}}' }}",
{%- endfor %}
},
{%- if k8s_templates[node.name].template %}
kubernetes_pod_template=f"""{{ k8s_templates[node.name].template | safe }}""",
{%- endif %}
)
{%- endif %}
{% endfor %}
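# Recreate the kedro dependency graph: every group runs before its children.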
{% for task_group in pipeline_grouped -%}
{% for child in task_group.children %}
tasks["{{ task_group.name | slugify }}"] >> tasks["{{ child.name | slugify }}"]
{% endfor %}
{%- endfor %}
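{# Optionally gate the kedro group on upstream DAGs finishing via ExternalTaskSensor. #}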
{% if with_external_dependencies %}
with TaskGroup("external_pipelines", prefix_group_id=False):
{% for dependency in config.run_config.external_dependencies %}
ExternalTaskSensor(external_dag_id='{{ dependency.dag_id }}',
external_task_id={% if dependency.task_id %}'{{ dependency.task_id }}'{% else %}None{% endif %},
check_existence={{ dependency.check_existence }},
execution_delta=timedelta(minutes={{ dependency.execution_delta }}),
task_id='wait_for_{{ dependency.dag_id }}_{{ dependency.task_id }}',
timeout={{ dependency.timeout }} * 60,
) >> kedro_group
{% endfor %}
{% endif %}
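{# Order the storage lifecycle and data init around the kedro group. #}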
{% if not config.run_config.volume.disabled %}
{% if not config.run_config.volume.skip_init %}
create_pipeline_storage >> data_volume_init
data_volume_init >> delete_pipeline_storage
{% else %}
create_pipeline_storage >> delete_pipeline_storage
{% endif %}
{% endif %}
{% if mlflow_url %}
start_mlflow_run >> kedro_group
{% endif %}
{% if not config.run_config.volume.disabled and not config.run_config.volume.skip_init %}
data_volume_init >> kedro_group
{% endif %}
{% if not config.run_config.volume.disabled %}
kedro_group >> delete_pipeline_storage >> final
{% else %}
kedro_group >> final
{% endif %}
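# Allows exercising this DAG file directly from the command line.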
if __name__ == "__main__":
dag.cli()