From 8f9309f10ccd0537778e87626ffd35b37cc50cfe Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Wed, 10 Jul 2024 23:55:37 +0200 Subject: [PATCH] bugfix/repair-databricks-plugin --- .../databricks/plugins/databricks_workflow.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/airflow/providers/databricks/plugins/databricks_workflow.py b/airflow/providers/databricks/plugins/databricks_workflow.py index 186f14d02afdb..03f989e899ddb 100644 --- a/airflow/providers/databricks/plugins/databricks_workflow.py +++ b/airflow/providers/databricks/plugins/databricks_workflow.py @@ -19,12 +19,12 @@ import logging import os -from operator import itemgetter from typing import TYPE_CHECKING, Any, cast from flask import current_app, flash, redirect, request, url_for from flask_appbuilder.api import expose +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.configuration import conf from airflow.exceptions import AirflowException, TaskInstanceNotFound from airflow.models import BaseOperator, BaseOperatorLink @@ -39,6 +39,7 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup +from airflow.www import auth from airflow.www.views import AirflowBaseView if TYPE_CHECKING: @@ -397,16 +398,15 @@ class RepairDatabricksTasks(AirflowBaseView, LoggingMixin): default_view = "repair" @expose("/repair_databricks_job", methods=("GET",)) + @auth.has_access_dag("GET", DagAccessEntity.TASK_INSTANCE) def repair(self): - databricks_conn_id, databricks_run_id, dag_id, tasks_to_repair = itemgetter( - "databricks_conn_id", "databricks_run_id", "dag_id", "tasks_to_repair" - )(request.values) - view = conf.get("webserver", "dag_default_view") - return_url = self._get_return_url(dag_id, view) - run_id = request.values.get("run_id").replace( - " ", "+" - ) # get run id separately since we need to modify it - if not tasks_to_repair: + databricks_conn_id = request.values.get("databricks_conn_id") + databricks_run_id = request.values.get("databricks_run_id") + dag_id = request.values.get("dag_id") + tasks_to_repair = request.values.get("tasks_to_repair") + run_id = request.values.get("run_id") + return_url = RepairDatabricksTasks._get_return_url(dag_id) or request.referrer or "/" + if not all([databricks_conn_id, tasks_to_repair, databricks_run_id, run_id]): # If there are no tasks to repair, we return. flash("No tasks to repair. Not sending repair request.") return redirect(return_url) @@ -418,23 +418,23 @@ def repair(self): tasks_to_repair=tasks_to_repair.split(","), logger=self.log, ) - self.log.info("Repairing databricks job query for run %s sent", databricks_run_id) - self.log.info("Clearing tasks to rerun in airflow") + run_id = run_id.replace(" ", "+") # get run id separately since we need to modify it _clear_task_instances(dag_id, run_id, tasks_to_repair.split(","), self.log) flash(f"Databricks repair job is starting!: {res}") return redirect(return_url) @staticmethod - def _get_return_url(dag_id: str, view) -> str: - return f"/dags/{dag_id}/{view}" + def _get_return_url(dag_id: str) -> str | None: + view = conf.get("webserver", "dag_default_view") + return f"/dags/{dag_id}/{view}" if dag_id else None repair_databricks_view = RepairDatabricksTasks() repair_databricks_package = { "name": "Repair Databricks View", - "category": "Repair Databricks Plugin", + "category": "Admin", "view": repair_databricks_view, }