getindata/kedro-airflow-k8s

View on GitHub
kedro_airflow_k8s/airflow_spark_task_template.j2

Summary

Maintainability
Test Coverage
import logging.config
import os
import sys
from typing import Any, Dict, Optional

from kedro.framework.startup import _get_project_metadata
from kedro.runner import ThreadRunner

# Holds the exception raised by the most recent failed Kedro startup, if any.
startup_error = None
# Project root directory: the PROJECT_HOME environment variable takes
# precedence, otherwise the templated default location under /opt is used.
project_path = os.environ.get("PROJECT_HOME", "/opt/{{ project_name }}")

def _purge_cached_modules(package_name: str) -> None:
    """Remove the project package and its submodules from ``sys.modules``.

    Only the package itself and names below ``package_name + "."`` are
    dropped, so unrelated modules that merely share a name prefix (for
    example ``foo_bar`` when the package is ``foo``) stay loaded.
    """
    submodule_prefix = package_name + "."
    stale = [
        name
        for name in sys.modules
        if name == package_name or name.startswith(submodule_prefix)
    ]
    # `del` is used instead of `reload()` because: If the new version of a module does not
    # define a name that was defined by the old version, the old definition remains.
    for name in stale:
        del sys.modules[name]


def init_kedro(
    path,
    env: Optional[str] = None,
    extra_params: Optional[Dict[str, Any]] = None,
):
    """Bootstrap a Kedro session and publish it through module-level globals.

    The function unregisters every plugin from Kedro's hook manager, evicts
    the project's cached modules, configures the project, creates and
    activates a ``KedroSession`` and loads its context. On success the
    globals ``session``, ``context`` and ``catalog`` are (re)bound.

    Args:
        path: Project root directory. When falsy, the module-level
            ``project_path`` is used instead.
        env: Forwarded unchanged as ``env`` to ``KedroSession.create``.
        extra_params: Forwarded unchanged as ``extra_params`` to
            ``KedroSession.create``.

    Raises:
        ImportError: If the required Kedro modules cannot be imported.
        Exception: Any error raised during startup is stored in the global
            ``startup_error``, logged with its traceback and re-raised.
    """
    global startup_error, context, catalog, session

    try:
        # Imported for its side effect only; presumably installs Kedro's
        # default logging configuration -- confirm against the Kedro version.
        import kedro.config.default_logger  # noqa: F401
        from kedro.framework.hooks import get_hook_manager
        from kedro.framework.project import configure_project
        from kedro.framework.session import KedroSession
        from kedro.framework.session.session import _activate_session
    except ImportError:
        logging.error(
            "Kedro appears not to be installed in your current environment "
            "or your current IPython session was not started in a valid Kedro project."
        )
        raise

    try:
        path = path or project_path

        # Clear the hook manager so hooks are not registered twice when the
        # project is configured again below.
        hook_manager = get_hook_manager()
        for name, plugin in hook_manager.list_name_plugin():
            hook_manager.unregister(name=name, plugin=plugin)

        # Remove cached user modules so the project code is imported afresh.
        metadata = _get_project_metadata(path)
        _purge_cached_modules(metadata.package_name)

        configure_project(metadata.package_name)
        session = KedroSession.create(
            metadata.package_name, path, env=env, extra_params=extra_params
        )
        _activate_session(session, force=True)
        logging.debug("Loading the context from %s", str(path))
        context = session.load_context()
        catalog = context.catalog

        logging.info("** Kedro project %s", str(metadata.project_name))
        logging.info("Defined global variable `context`, `session` and `catalog`")

    except Exception as err:
        # Keep the failure inspectable through the module-level variable
        # before propagating it to the caller.
        startup_error = err
        logging.exception(
            "Kedro's session startup script failed:\n%s", str(err)
        )
        raise


if __name__ == "__main__":
    # Entry point of the rendered script: start Kedro, then run the nodes
    # selected for this task.
    # The env keyword argument is emitted only when the template is rendered
    # with a truthy kedro_env value; otherwise init_kedro receives the path alone.
    init_kedro(project_path{% if kedro_env %}, env='{{ kedro_env }}'{% endif %})
    # `session` is the module-level global bound by init_kedro above.
    # The template loop below emits one double-quoted entry per node name; its
    # whitespace-control markers collapse the entries onto a single line.
    session.run(runner=ThreadRunner(), node_names=[
            {%- for node_name in node_names -%}
            "{{ node_name }}",
        {%- endfor -%}
    ])