diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index d322519230d25..a263fa9106a11 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -35,6 +35,10 @@ DatabricksWorkflowTaskGroup, WorkflowRunMetadata, ) +from airflow.providers.databricks.plugins.databricks_workflow import ( + WorkflowJobRepairSingleTaskLink, + WorkflowJobRunLink, +) from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event @@ -958,6 +962,15 @@ def __init__( super().__init__(**kwargs) + if self._databricks_workflow_task_group is not None: + self.operator_extra_links = ( + WorkflowJobRunLink(), + WorkflowJobRepairSingleTaskLink(), + ) + else: + # Databricks does not support repair for non-workflow tasks, hence do not show the repair link. + self.operator_extra_links = (DatabricksJobRunLink(),) + @cached_property def _hook(self) -> DatabricksHook: return self._get_hook(caller=self.caller) @@ -1016,12 +1029,17 @@ def _get_run_json(self) -> dict[str, Any]: raise ValueError("Must specify either existing_cluster_id or new_cluster.") return run_json - def _launch_job(self) -> int: + def _launch_job(self, context: Context | None = None) -> int: """Launch the job on Databricks.""" run_json = self._get_run_json() self.databricks_run_id = self._hook.submit_run(run_json) url = self._hook.get_run_page_url(self.databricks_run_id) self.log.info("Check the job run in Databricks: %s", url) + + if self.do_xcom_push and context is not None: + context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=self.databricks_run_id) + context["ti"].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=url) + return self.databricks_run_id def _handle_terminal_run_state(self, run_state: RunState) -> None: @@ -1040,7 +1058,15 @@ def _get_current_databricks_task(self) -> dict[str, Any]: """Retrieve the Databricks task corresponding to the current Airflow task.""" if self.databricks_run_id is None: raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.") - return {task["task_key"]: task for task in self._hook.get_run(self.databricks_run_id)["tasks"]}[ + tasks = self._hook.get_run(self.databricks_run_id)["tasks"] + + # Because the task_key remains the same across multiple runs, and the Databricks API does not return + # tasks sorted by their attempts/start time, we sort the tasks by start time. This ensures that we + # map the latest attempt (whose status is to be monitored) of the task run to the task_key while + # building the {task_key: task} map below. 
+ sorted_task_runs = sorted(tasks, key=lambda x: x["start_time"]) + + return {task["task_key"]: task for task in sorted_task_runs}[ self._get_databricks_task_id(self.task_id) ] @@ -1125,7 +1151,7 @@ def execute(self, context: Context) -> None: self.databricks_run_id = workflow_run_metadata.run_id self.databricks_conn_id = workflow_run_metadata.conn_id else: - self._launch_job() + self._launch_job(context=context) if self.wait_for_termination: self.monitor_databricks_job() diff --git a/airflow/providers/databricks/operators/databricks_workflow.py b/airflow/providers/databricks/operators/databricks_workflow.py index 8203145314fd0..15333dc69118b 100644 --- a/airflow/providers/databricks/operators/databricks_workflow.py +++ b/airflow/providers/databricks/operators/databricks_workflow.py @@ -28,6 +28,10 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState +from airflow.providers.databricks.plugins.databricks_workflow import ( + WorkflowJobRepairAllFailedLink, + WorkflowJobRunLink, +) from airflow.utils.task_group import TaskGroup if TYPE_CHECKING: @@ -88,6 +92,7 @@ class _CreateDatabricksWorkflowOperator(BaseOperator): populated after instantiation using the `add_task` method. """ + operator_extra_links = (WorkflowJobRunLink(), WorkflowJobRepairAllFailedLink()) template_fields = ("notebook_params",) caller = "_CreateDatabricksWorkflowOperator" diff --git a/airflow/providers/databricks/plugins/__init__.py b/airflow/providers/databricks/plugins/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/databricks/plugins/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/databricks/plugins/databricks_workflow.py b/airflow/providers/databricks/plugins/databricks_workflow.py new file mode 100644 index 0000000000000..41c7b6735759f --- /dev/null +++ b/airflow/providers/databricks/plugins/databricks_workflow.py @@ -0,0 +1,479 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING, Any, cast +from urllib.parse import unquote + +from flask import current_app, flash, redirect, request, url_for +from flask_appbuilder.api import expose +from packaging.version import Version + +from airflow.configuration import conf +from airflow.exceptions import AirflowException, TaskInstanceNotFound +from airflow.models import BaseOperator, BaseOperatorLink +from airflow.models.dag import DAG, clear_task_instances +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance, TaskInstanceKey +from airflow.models.xcom import XCom +from airflow.plugins_manager import AirflowPlugin +from airflow.providers.databricks.hooks.databricks import DatabricksHook +from airflow.security import permissions +from airflow.utils.airflow_flask_app import AirflowApp +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.state import TaskInstanceState +from airflow.utils.task_group import TaskGroup +from airflow.version import version +from airflow.www import auth +from airflow.www.views import AirflowBaseView + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + + +REPAIR_WAIT_ATTEMPTS = os.getenv("DATABRICKS_REPAIR_WAIT_ATTEMPTS", 20) +REPAIR_WAIT_DELAY = os.getenv("DATABRICKS_REPAIR_WAIT_DELAY", 0.5) + +airflow_app = cast(AirflowApp, current_app) + + +def get_auth_decorator(): + # TODO: remove this if block when min_airflow_version is set to higher than 2.8.0 + if Version(version) < Version("2.8"): + return auth.has_access( + [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), + ] + ) + + from airflow.auth.managers.models.resource_details import DagAccessEntity + + return auth.has_access_dag("POST", DagAccessEntity.RUN) + + +def _get_databricks_task_id(task: BaseOperator) -> str: + """ + Get the Databricks task ID using dag_id and task_id, removing illegal characters. + + :param task: The task to get the Databricks task ID for. + :return: The Databricks task ID. + """ + return f"{task.dag_id}__{task.task_id.replace('.', '__')}" + + +def get_databricks_task_ids( + group_id: str, task_map: dict[str, BaseOperator], log: logging.Logger +) -> list[str]: + """ + Return a list of all Databricks task IDs for a dictionary of Airflow tasks. + + :param group_id: The task group ID. + :param task_map: A dictionary mapping task IDs to BaseOperator instances. + :param log: The logger to use for logging. + :return: A list of Databricks task IDs for the given task group. + """ + task_ids = [] + log.debug("Getting databricks task ids for group %s", group_id) + for task_id, task in task_map.items(): + if task_id == f"{group_id}.launch": + continue + databricks_task_id = _get_databricks_task_id(task) + log.debug("databricks task id for task %s is %s", task_id, databricks_task_id) + task_ids.append(databricks_task_id) + return task_ids + + +@provide_session +def _get_dagrun(dag: DAG, run_id: str, session: Session | None = None) -> DagRun: + """ + Retrieve the DagRun object associated with the specified DAG and run_id. + + :param dag: The DAG object associated with the DagRun to retrieve. + :param run_id: The run_id associated with the DagRun to retrieve. + :param session: The SQLAlchemy session to use for the query. 
If None, uses the default session. + :return: The DagRun object associated with the specified DAG and run_id. + """ + if not session: + raise AirflowException("Session not provided.") + + return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first() + + +@provide_session +def _clear_task_instances( + dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session: Session | None = None +) -> None: + dag = airflow_app.dag_bag.get_dag(dag_id) + log.debug("task_ids %s to clear", str(task_ids)) + dr: DagRun = _get_dagrun(dag, run_id, session=session) + tis_to_clear = [ti for ti in dr.get_task_instances() if _get_databricks_task_id(ti) in task_ids] + clear_task_instances(tis_to_clear, session) + + +def _repair_task( + databricks_conn_id: str, + databricks_run_id: int, + tasks_to_repair: list[str], + logger: logging.Logger, +) -> int: + """ + Repair a Databricks task using the Databricks API. + + This function allows the Airflow retry function to create a repair job for Databricks. + It uses the Databricks API to get the latest repair ID before sending the repair query. + + :param databricks_conn_id: The Databricks connection ID. + :param databricks_run_id: The Databricks run ID. + :param tasks_to_repair: A list of Databricks task IDs to repair. + :param logger: The logger to use for logging. + :return: The ID of the repair run. + """ + hook = DatabricksHook(databricks_conn_id=databricks_conn_id) + + repair_history_id = hook.get_latest_repair_id(databricks_run_id) + logger.debug("Latest repair ID is %s", repair_history_id) + logger.debug( + "Sending repair query for tasks %s on run %s", + tasks_to_repair, + databricks_run_id, + ) + + repair_json = { + "run_id": databricks_run_id, + "latest_repair_id": repair_history_id, + "rerun_tasks": tasks_to_repair, + } + + return hook.repair_run(repair_json) + + +def get_launch_task_id(task_group: TaskGroup) -> str: + """ + Retrieve the launch task ID from the current task group or a parent task group, recursively. + + :param task_group: Task Group to be inspected + :return: launch Task ID + """ + try: + launch_task_id = task_group.get_child_by_label("launch").task_id # type: ignore[attr-defined] + except KeyError as e: + if not task_group.parent_group: + raise AirflowException("No launch task can be found in the task group.") from e + launch_task_id = get_launch_task_id(task_group.parent_group) + + return launch_task_id + + +def _get_launch_task_key(current_task_key: TaskInstanceKey, task_id: str) -> TaskInstanceKey: + """ + Return the task key for the launch task. + + This allows us to gather Databricks metadata even if the current task has failed (since tasks only + create XCom values if they succeed). + + :param current_task_key: The task key for the current task. + :param task_id: The task ID of the launch task. + :return: The task key for the launch task. 
+ """ + if task_id: + return TaskInstanceKey( + dag_id=current_task_key.dag_id, + task_id=task_id, + run_id=current_task_key.run_id, + try_number=current_task_key.try_number, + ) + + return current_task_key + + +@provide_session +def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance: + dag_id = operator.dag.dag_id + dag_run = DagRun.find(dag_id, execution_date=dttm)[0] + ti = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == dag_id, + TaskInstance.run_id == dag_run.run_id, + TaskInstance.task_id == operator.task_id, + ) + .one_or_none() + ) + if not ti: + raise TaskInstanceNotFound("Task instance not found") + return ti + + +def get_xcom_result( + ti_key: TaskInstanceKey, + key: str, +) -> Any: + result = XCom.get_value( + ti_key=ti_key, + key=key, + ) + from airflow.providers.databricks.operators.databricks_workflow import WorkflowRunMetadata + + return WorkflowRunMetadata(**result) + + +class WorkflowJobRunLink(BaseOperatorLink, LoggingMixin): + """Constructs a link to monitor a Databricks Job Run.""" + + name = "See Databricks Job Run" + + def get_link( + self, + operator: BaseOperator, + dttm=None, + *, + ti_key: TaskInstanceKey | None = None, + ) -> str: + if not ti_key: + ti = get_task_instance(operator, dttm) + ti_key = ti.key + task_group = operator.task_group + + if not task_group: + raise AirflowException("Task group is required for generating Databricks Workflow Job Run Link.") + + dag = airflow_app.dag_bag.get_dag(ti_key.dag_id) + dag.get_task(ti_key.task_id) + self.log.info("Getting link for task %s", ti_key.task_id) + if ".launch" not in ti_key.task_id: + self.log.debug("Finding the launch task for job run metadata %s", ti_key.task_id) + launch_task_id = get_launch_task_id(task_group) + ti_key = _get_launch_task_key(ti_key, task_id=launch_task_id) + metadata = get_xcom_result(ti_key, "return_value") + + hook = DatabricksHook(metadata.conn_id) + return f"https://{hook.host}/#job/{metadata.job_id}/run/{metadata.run_id}" + + +class WorkflowJobRepairAllFailedLink(BaseOperatorLink, LoggingMixin): + """Constructs a link to send a request to repair all failed tasks in the Databricks workflow.""" + + name = "Repair All Failed Tasks" + + def get_link( + self, + operator, + dttm=None, + *, + ti_key: TaskInstanceKey | None = None, + ) -> str: + if not ti_key: + ti = get_task_instance(operator, dttm) + ti_key = ti.key + task_group = operator.task_group + self.log.debug( + "Creating link to repair all tasks for databricks job run %s", + task_group.group_id, + ) + + metadata = get_xcom_result(ti_key, "return_value") + + tasks_str = self.get_tasks_to_run(ti_key, operator, self.log) + self.log.debug("tasks to rerun: %s", tasks_str) + + query_params = { + "dag_id": ti_key.dag_id, + "databricks_conn_id": metadata.conn_id, + "databricks_run_id": metadata.run_id, + "run_id": ti_key.run_id, + "tasks_to_repair": tasks_str, + } + + return url_for("RepairDatabricksTasks.repair", **query_params) + + @classmethod + def get_task_group_children(cls, task_group: TaskGroup) -> dict[str, BaseOperator]: + """ + Given a TaskGroup, return children which are Tasks, inspecting recursively any TaskGroups within. + + :param task_group: An Airflow TaskGroup + :return: Dictionary that contains Task IDs as keys and Tasks as values. 
+ """ + children: dict[str, Any] = {} + for child_id, child in task_group.children.items(): + if isinstance(child, TaskGroup): + child_children = cls.get_task_group_children(child) + children = {**children, **child_children} + else: + children[child_id] = child + return children + + def get_tasks_to_run(self, ti_key: TaskInstanceKey, operator: BaseOperator, log: logging.Logger) -> str: + task_group = operator.task_group + if not task_group: + raise AirflowException("Task group is required for generating repair link.") + if not task_group.group_id: + raise AirflowException("Task group ID is required for generating repair link.") + dag = airflow_app.dag_bag.get_dag(ti_key.dag_id) + dr = _get_dagrun(dag, ti_key.run_id) + log.debug("Getting failed and skipped tasks for dag run %s", dr.run_id) + task_group_sub_tasks = self.get_task_group_children(task_group).items() + failed_and_skipped_tasks = self._get_failed_and_skipped_tasks(dr) + log.debug("Failed and skipped tasks: %s", failed_and_skipped_tasks) + + tasks_to_run = {ti: t for ti, t in task_group_sub_tasks if ti in failed_and_skipped_tasks} + + return ",".join(get_databricks_task_ids(task_group.group_id, tasks_to_run, log)) + + @staticmethod + def _get_failed_and_skipped_tasks(dr: DagRun) -> list[str]: + """ + Return a list of task IDs for tasks that have failed or have been skipped in the given DagRun. + + :param dr: The DagRun object for which to retrieve failed and skipped tasks. + + :return: A list of task IDs for tasks that have failed or have been skipped. + """ + return [ + t.task_id + for t in dr.get_task_instances( + state=[ + TaskInstanceState.FAILED, + TaskInstanceState.SKIPPED, + TaskInstanceState.UP_FOR_RETRY, + TaskInstanceState.UPSTREAM_FAILED, + None, + ], + ) + ] + + +class WorkflowJobRepairSingleTaskLink(BaseOperatorLink, LoggingMixin): + """Construct a link to send a repair request for a single databricks task.""" + + name = "Repair a single task" + + def get_link( + self, + operator, + dttm=None, + *, + ti_key: TaskInstanceKey | None = None, + ) -> str: + if not ti_key: + ti = get_task_instance(operator, dttm) + ti_key = ti.key + + task_group = operator.task_group + if not task_group: + raise AirflowException("Task group is required for generating repair link.") + + self.log.info( + "Creating link to repair a single task for databricks job run %s task %s", + task_group.group_id, + ti_key.task_id, + ) + dag = airflow_app.dag_bag.get_dag(ti_key.dag_id) + task = dag.get_task(ti_key.task_id) + + if ".launch" not in ti_key.task_id: + launch_task_id = get_launch_task_id(task_group) + ti_key = _get_launch_task_key(ti_key, task_id=launch_task_id) + metadata = get_xcom_result(ti_key, "return_value") + + query_params = { + "dag_id": ti_key.dag_id, + "databricks_conn_id": metadata.conn_id, + "databricks_run_id": metadata.run_id, + "run_id": ti_key.run_id, + "tasks_to_repair": _get_databricks_task_id(task), + } + return url_for("RepairDatabricksTasks.repair", **query_params) + + +class RepairDatabricksTasks(AirflowBaseView, LoggingMixin): + """Repair databricks tasks from Airflow.""" + + default_view = "repair" + + @expose("/repair_databricks_job//", methods=("GET",)) + @get_auth_decorator() + def repair(self, dag_id: str, run_id: str): + view = conf.get("webserver", "dag_default_view") + return_url = self._get_return_url(dag_id, view) + + tasks_to_repair = request.values.get("tasks_to_repair") + self.log.info("Tasks to repair: %s", tasks_to_repair) + if not tasks_to_repair: + flash("No tasks to repair. 
Not sending repair request.") + return redirect(return_url) + + databricks_conn_id = request.values.get("databricks_conn_id") + databricks_run_id = request.values.get("databricks_run_id") + + if not databricks_conn_id: + flash("No Databricks connection ID provided. Cannot repair tasks.") + return redirect(return_url) + + if not databricks_run_id: + flash("No Databricks run ID provided. Cannot repair tasks.") + return redirect(return_url) + + self.log.info("Repairing databricks job %s", databricks_run_id) + res = _repair_task( + databricks_conn_id=databricks_conn_id, + databricks_run_id=int(databricks_run_id), + tasks_to_repair=tasks_to_repair.split(","), + logger=self.log, + ) + self.log.info("Repairing databricks job query for run %s sent", databricks_run_id) + + self.log.info("Clearing tasks to rerun in airflow") + + run_id = unquote(run_id) + _clear_task_instances(dag_id, run_id, tasks_to_repair.split(","), self.log) + flash(f"Databricks repair job is starting!: {res}") + return redirect(return_url) + + @staticmethod + def _get_return_url(dag_id: str, view) -> str: + return f"/dags/{dag_id}/{view}" + + +repair_databricks_view = RepairDatabricksTasks() + +repair_databricks_package = { + "view": repair_databricks_view, +} + + +class DatabricksWorkflowPlugin(AirflowPlugin): + """ + Databricks Workflows plugin for Airflow. + + .. seealso:: + For more information on how to use this plugin, take a look at the guide: + :ref:`howto/plugin:DatabricksWorkflowPlugin` + """ + + name = "databricks_workflow" + operator_extra_links = [ + WorkflowJobRepairAllFailedLink(), + WorkflowJobRepairSingleTaskLink(), + WorkflowJobRunLink(), + ] + appbuilder_views = [repair_databricks_package] diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index 930813ced3284..0132982659884 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -165,5 +165,9 @@ connection-types: - hook-class-name: airflow.providers.databricks.hooks.databricks.DatabricksHook connection-type: databricks +plugins: + - name: databricks_workflow + plugin-class: airflow.providers.databricks.plugins.databricks_workflow.DatabricksWorkflowPlugin + extra-links: - airflow.providers.databricks.operators.databricks.DatabricksJobRunLink diff --git a/docs/apache-airflow-providers-databricks/img/workflow_plugin_launch_task.png b/docs/apache-airflow-providers-databricks/img/workflow_plugin_launch_task.png new file mode 100644 index 0000000000000..e99083f53ffcd Binary files /dev/null and b/docs/apache-airflow-providers-databricks/img/workflow_plugin_launch_task.png differ diff --git a/docs/apache-airflow-providers-databricks/img/workflow_plugin_single_task.png b/docs/apache-airflow-providers-databricks/img/workflow_plugin_single_task.png new file mode 100644 index 0000000000000..17a130b944e5f Binary files /dev/null and b/docs/apache-airflow-providers-databricks/img/workflow_plugin_single_task.png differ diff --git a/docs/apache-airflow-providers-databricks/index.rst b/docs/apache-airflow-providers-databricks/index.rst index 3358bd8bb1061..4e010d643e794 100644 --- a/docs/apache-airflow-providers-databricks/index.rst +++ b/docs/apache-airflow-providers-databricks/index.rst @@ -36,6 +36,7 @@ Connection types Operators + Plugins <plugins/index> .. 
toctree:: :hidden: diff --git a/docs/apache-airflow-providers-databricks/plugins/index.rst b/docs/apache-airflow-providers-databricks/plugins/index.rst new file mode 100644 index 0000000000000..5ddb65f6f3b3d --- /dev/null +++ b/docs/apache-airflow-providers-databricks/plugins/index.rst @@ -0,0 +1,28 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +Databricks Plugins +================== + + +.. toctree:: + :maxdepth: 1 + :glob: + + * diff --git a/docs/apache-airflow-providers-databricks/plugins/workflow.rst b/docs/apache-airflow-providers-databricks/plugins/workflow.rst new file mode 100644 index 0000000000000..22acd05596791 --- /dev/null +++ b/docs/apache-airflow-providers-databricks/plugins/workflow.rst @@ -0,0 +1,60 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/plugin:DatabricksWorkflowPlugin: + + +DatabricksWorkflowPlugin +======================== + + +Overview +-------- + +The ``DatabricksWorkflowPlugin`` enhances the Airflow UI by providing links for tasks that allow users to view the +Databricks job run in the Databricks workspace. Additionally, it offers links to repair task(s) within the workflow. + +Features +-------- + +- **Task-Level Links**: Within the workflow, each task includes a link to the job run and a repair link for the individual task. + +- **Workflow-Level Links**: At the workflow level, for the job launch task, the plugin provides a link to repair all failed tasks and a link to the job run in the Databricks workspace, which allows users to monitor the job. + +Examples +-------- + +- **Job Run Link and Repair Link for a Single Task**: + +.. image:: ../img/workflow_plugin_single_task.png + +- **Workflow-Level Links to the job run and to repair all failed tasks**: + +.. image:: ../img/workflow_plugin_launch_task.png + +Notes +----- + +Databricks does not allow repairing jobs for single tasks launched outside the workflow. Hence, for these tasks, only the job run link is provided. 
+ +Usage +----- + +Ideally, installing the provider will also install the plugin, and it should work automatically in your deployment. +However, if custom configurations are preventing the use of plugins, ensure the plugin is properly installed and +configured in your Airflow environment to utilize its features. The plugin will automatically detect Databricks jobs, +as the links are embedded in the relevant operators. diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 9d025228867ae..c9de823c0035a 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -435,7 +435,12 @@ "devel-deps": [ "deltalake>=0.12.0" ], - "plugins": [], + "plugins": [ + { + "name": "databricks_workflow", + "plugin-class": "airflow.providers.databricks.plugins.databricks_workflow.DatabricksWorkflowPlugin" + } + ], "cross-providers-deps": [ "common.sql" ], diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index b60a9de17999d..a8962b1a61797 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -414,7 +414,7 @@ def test_does_not_double_import_entrypoint_provider_plugins(self): assert len(plugins_manager.plugins) == 0 plugins_manager.load_entrypoint_plugins() plugins_manager.load_providers_plugins() - assert len(plugins_manager.plugins) == 2 + assert len(plugins_manager.plugins) == 3 class TestPluginsDirectorySource: diff --git a/tests/providers/databricks/plugins/__init__.py b/tests/providers/databricks/plugins/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/databricks/plugins/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/databricks/plugins/test_databricks_workflow.py b/tests/providers/databricks/plugins/test_databricks_workflow.py new file mode 100644 index 0000000000000..ec498f153abe7 --- /dev/null +++ b/tests/providers/databricks/plugins/test_databricks_workflow.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from airflow.exceptions import AirflowException +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstanceKey +from airflow.plugins_manager import AirflowPlugin +from airflow.providers.databricks.plugins.databricks_workflow import ( + DatabricksWorkflowPlugin, + RepairDatabricksTasks, + WorkflowJobRepairSingleTaskLink, + WorkflowJobRunLink, + _get_dagrun, + _get_databricks_task_id, + _get_launch_task_key, + _repair_task, + get_databricks_task_ids, + get_launch_task_id, + get_task_instance, +) +from airflow.utils.dates import days_ago +from airflow.www.app import create_app + +DAG_ID = "test_dag" +TASK_ID = "test_task" +RUN_ID = "test_run_1" +DAG_RUN_DATE = days_ago(1) +TASK_INSTANCE_KEY = TaskInstanceKey(dag_id=DAG_ID, task_id=TASK_ID, run_id=RUN_ID, try_number=1) +DATABRICKS_CONN_ID = "databricks_default" +DATABRICKS_RUN_ID = 12345 +GROUP_ID = "test_group" +TASK_MAP = { + "task1": MagicMock(dag_id=DAG_ID, task_id="task1"), + "task2": MagicMock(dag_id=DAG_ID, task_id="task2"), +} +LOG = MagicMock() + + +@pytest.mark.parametrize( + "task, expected_id", + [ + (MagicMock(dag_id="dag1", task_id="task.1"), "dag1__task__1"), + (MagicMock(dag_id="dag2", task_id="task_1"), "dag2__task_1"), + ], +) +def test_get_databricks_task_id(task, expected_id): + result = _get_databricks_task_id(task) + + assert result == expected_id + + +def test_get_databricks_task_ids(): + result = get_databricks_task_ids(GROUP_ID, TASK_MAP, LOG) + + expected_ids = ["test_dag__task1", "test_dag__task2"] + assert result == expected_ids + + +def test_get_dagrun(): + session = MagicMock() + dag = MagicMock(dag_id=DAG_ID) + session.query.return_value.filter.return_value.first.return_value = DagRun() + + result = _get_dagrun(dag, RUN_ID, session=session) + + assert isinstance(result, DagRun) + + +@patch("airflow.providers.databricks.plugins.databricks_workflow.DatabricksHook") +def test_repair_task(mock_databricks_hook): + mock_hook_instance = mock_databricks_hook.return_value + mock_hook_instance.get_latest_repair_id.return_value = 100 + mock_hook_instance.repair_run.return_value = 200 + + tasks_to_repair = ["task1", "task2"] + result = _repair_task(DATABRICKS_CONN_ID, DATABRICKS_RUN_ID, tasks_to_repair, LOG) + + assert result == 200 + mock_hook_instance.get_latest_repair_id.assert_called_once_with(DATABRICKS_RUN_ID) + mock_hook_instance.repair_run.assert_called_once() + + +def test_get_launch_task_id_no_launch_task(): + task_group = MagicMock(get_child_by_label=MagicMock(side_effect=KeyError)) + task_group.parent_group = None + + with pytest.raises(AirflowException): + get_launch_task_id(task_group) + + +def test_get_launch_task_key(): + result = _get_launch_task_key(TASK_INSTANCE_KEY, "launch_task") + + assert isinstance(result, TaskInstanceKey) + assert result.dag_id == TASK_INSTANCE_KEY.dag_id + assert result.task_id == "launch_task" + assert result.run_id == TASK_INSTANCE_KEY.run_id + + +@pytest.fixture(scope="session") +def app(): + app = create_app(testing=True) + app.config["SERVER_NAME"] = "localhost" + + with app.app_context(): + yield app + + +def test_get_task_instance(app): + with app.app_context(): + operator = Mock() + operator.dag.dag_id = "dag_id" + operator.task_id = "task_id" + dttm = "2022-01-01T00:00:00Z" + session = Mock() + dag_run = Mock() + 
session.query().filter().one_or_none.return_value = dag_run + + with patch( + "airflow.providers.databricks.plugins.databricks_workflow.DagRun.find", return_value=[dag_run] + ): + result = get_task_instance(operator, dttm, session) + assert result == dag_run + + +def test_workflow_job_run_link(app): + with app.app_context(): + link = WorkflowJobRunLink() + operator = Mock() + ti_key = Mock() + ti_key.dag_id = "dag_id" + ti_key.task_id = "task_id" + ti_key.run_id = "run_id" + ti_key.try_number = 1 + + with patch( + "airflow.providers.databricks.plugins.databricks_workflow.get_task_instance" + ) as mock_get_task_instance: + with patch( + "airflow.providers.databricks.plugins.databricks_workflow.get_xcom_result" + ) as mock_get_xcom_result: + with patch( + "airflow.providers.databricks.plugins.databricks_workflow.airflow_app.dag_bag.get_dag" + ) as mock_get_dag: + mock_connection = Mock() + mock_connection.extra_dejson = {"host": "mockhost"} + + with patch( + "airflow.providers.databricks.hooks.databricks.DatabricksHook.get_connection", + return_value=mock_connection, + ): + mock_get_task_instance.return_value = Mock(key=ti_key) + mock_get_xcom_result.return_value = Mock(conn_id="conn_id", run_id=1, job_id=1) + mock_get_dag.return_value.get_task = Mock(return_value=Mock(task_id="task_id")) + + result = link.get_link(operator, ti_key=ti_key) + assert "https://mockhost/#job/1/run/1" in result + + +def test_workflow_job_repair_single_failed_link(app): + with app.app_context(): + link = WorkflowJobRepairSingleTaskLink() + operator = Mock() + operator.task_group = Mock() + operator.task_group.group_id = "group_id" + operator.task_group.get_child_by_label = Mock() + ti_key = Mock() + ti_key.dag_id = "dag_id" + ti_key.task_id = "task_id" + ti_key.run_id = "run_id" + ti_key.try_number = 1 + + with patch( + "airflow.providers.databricks.plugins.databricks_workflow.get_task_instance" + ) as mock_get_task_instance: + with patch( + "airflow.providers.databricks.plugins.databricks_workflow.get_xcom_result" + ) as mock_get_xcom_result: + with patch( + "airflow.providers.databricks.plugins.databricks_workflow.airflow_app.dag_bag.get_dag" + ) as mock_get_dag: + mock_get_task_instance.return_value = Mock(key=ti_key) + mock_get_xcom_result.return_value = Mock(conn_id="conn_id", run_id=1) + mock_get_dag.return_value.get_task = Mock(return_value=Mock(task_id="task_id")) + + result = link.get_link(operator, ti_key=ti_key) + assert result.startswith("http://localhost/repair_databricks_job") + + +@pytest.fixture +def plugin(): + return DatabricksWorkflowPlugin() + + +def test_plugin_is_airflow_plugin(plugin): + assert isinstance(plugin, AirflowPlugin) + + +def test_operator_extra_links(plugin): + for link in plugin.operator_extra_links: + assert hasattr(link, "get_link") + + +def test_appbuilder_views(plugin): + assert plugin.appbuilder_views is not None + assert len(plugin.appbuilder_views) == 1 + + repair_view = plugin.appbuilder_views[0]["view"] + assert isinstance(repair_view, RepairDatabricksTasks) + assert repair_view.default_view == "repair"
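
For context, a minimal sketch (not part of the diff above) of the kind of DAG these links are rendered for: tasks created inside a DatabricksWorkflowTaskGroup get the workflow-level and task-level links added here, while a standalone Databricks task only gets the job run link. The connection ID, cluster spec, and notebook paths below are illustrative placeholders, and the operator parameters are assumptions based on the operators referenced in this change rather than a definitive usage guide.

from __future__ import annotations

from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator
from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup

with DAG(dag_id="example_databricks_workflow", start_date=datetime(2024, 1, 1), schedule=None):
    # Tasks inside the task group run as a single Databricks workflow job; the plugin
    # attaches the job run link and repair links to the launch task and to each task.
    workflow = DatabricksWorkflowTaskGroup(
        group_id="databricks_workflow",
        databricks_conn_id="databricks_default",  # placeholder connection ID
        job_clusters=[
            {
                "job_cluster_key": "workflow_cluster",
                "new_cluster": {
                    "spark_version": "13.3.x-scala2.12",
                    "node_type_id": "i3.xlarge",
                    "num_workers": 1,
                },
            }
        ],
    )
    with workflow:
        first = DatabricksNotebookOperator(
            task_id="first_notebook",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/first_notebook",  # placeholder notebook path
            source="WORKSPACE",
            job_cluster_key="workflow_cluster",
        )
        second = DatabricksNotebookOperator(
            task_id="second_notebook",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/second_notebook",  # placeholder notebook path
            source="WORKSPACE",
            job_cluster_key="workflow_cluster",
        )
        first >> second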