Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions airflow/providers/databricks/plugins/databricks_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

import logging
import os
from operator import itemgetter
from typing import TYPE_CHECKING, Any, cast

from flask import current_app, flash, redirect, request, url_for
from flask_appbuilder.api import expose

from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.configuration import conf
from airflow.exceptions import AirflowException, TaskInstanceNotFound
from airflow.models import BaseOperator, BaseOperatorLink
Expand All @@ -39,6 +39,7 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.www import auth
from airflow.www.views import AirflowBaseView

if TYPE_CHECKING:
Expand Down Expand Up @@ -397,16 +398,15 @@ class RepairDatabricksTasks(AirflowBaseView, LoggingMixin):
default_view = "repair"

@expose("/repair_databricks_job", methods=("GET",))
@auth.has_access_dag("GET", DagAccessEntity.TASK_INSTANCE)
def repair(self):
databricks_conn_id, databricks_run_id, dag_id, tasks_to_repair = itemgetter(
"databricks_conn_id", "databricks_run_id", "dag_id", "tasks_to_repair"
)(request.values)
view = conf.get("webserver", "dag_default_view")
return_url = self._get_return_url(dag_id, view)
run_id = request.values.get("run_id").replace(
" ", "+"
) # get run id separately since we need to modify it
if not tasks_to_repair:
databricks_conn_id = request.values.get("databricks_conn_id")
databricks_run_id = request.values.get("databricks_run_id")
dag_id = request.values.get("dag_id")
tasks_to_repair = request.values.get("tasks_to_repair")
run_id = request.values.get("run_id")
return_url = RepairDatabricksTasks._get_return_url(dag_id) or request.referrer or "/"
if not all([databricks_conn_id, tasks_to_repair, databricks_run_id, run_id]):
# If there are no tasks to repair, we return.
flash("No tasks to repair. Not sending repair request.")
return redirect(return_url)
Expand All @@ -418,23 +418,23 @@ def repair(self):
tasks_to_repair=tasks_to_repair.split(","),
logger=self.log,
)
self.log.info("Repairing databricks job query for run %s sent", databricks_run_id)

self.log.info("Clearing tasks to rerun in airflow")
run_id = run_id.replace(" ", "+") # get run id separately since we need to modify it
_clear_task_instances(dag_id, run_id, tasks_to_repair.split(","), self.log)
flash(f"Databricks repair job is starting!: {res}")
return redirect(return_url)

@staticmethod
def _get_return_url(dag_id: str, view) -> str:
return f"/dags/{dag_id}/{view}"
def _get_return_url(dag_id: str) -> str | None:
view = conf.get("webserver", "dag_default_view")
return f"/dags/{dag_id}/{view}" if dag_id else None


repair_databricks_view = RepairDatabricksTasks()

repair_databricks_package = {
"name": "Repair Databricks View",
"category": "Repair Databricks Plugin",
"category": "Admin",
"view": repair_databricks_view,
}

Expand Down