@@ -42,6 +42,7 @@
from airflow.providers.databricks.plugins.databricks_workflow import (
WorkflowJobRepairSingleTaskLink,
WorkflowJobRunLink,
store_databricks_job_run_link,
)
from airflow.providers.databricks.triggers.databricks import (
DatabricksExecutionTrigger,
@@ -1214,10 +1215,16 @@ def __init__(
super().__init__(**kwargs)

if self._databricks_workflow_task_group is not None:
self.operator_extra_links = (
WorkflowJobRunLink(),
WorkflowJobRepairSingleTaskLink(),
)
# Conditionally set operator_extra_links based on Airflow version. In Airflow 3, only show the job run link.
# In Airflow 2, show the job run link and the repair link.
# TODO: Once the plugin functionality is expanded in Airflow 3.1, re-evaluate how to handle the repair link.
if AIRFLOW_V_3_0_PLUS:
self.operator_extra_links = (WorkflowJobRunLink(),)
else:
self.operator_extra_links = (
WorkflowJobRunLink(),
WorkflowJobRepairSingleTaskLink(),
)
else:
# Databricks does not support repair for non-workflow tasks, hence do not show the repair link.
self.operator_extra_links = (DatabricksJobRunLink(),)
@@ -1427,6 +1434,15 @@ def execute(self, context: Context) -> None:
)
self.databricks_run_id = workflow_run_metadata.run_id
self.databricks_conn_id = workflow_run_metadata.conn_id

# Store operator links in XCom for Airflow 3 compatibility
if AIRFLOW_V_3_0_PLUS:
# Store the job run link
store_databricks_job_run_link(
context=context,
metadata=workflow_run_metadata,
logger=self.log,
)
else:
self._launch_job(context=context)
if self.wait_for_termination:
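
The AIRFLOW_V_3_0_PLUS flag gating the new branches above is imported from airflow.providers.databricks.version_compat elsewhere in this diff. As a minimal, illustrative sketch of how such a flag is conventionally derived (the actual contents of version_compat are an assumption here, not part of this change), it boils down to a version comparison:

    from packaging.version import Version

    from airflow import __version__ as airflow_version

    # True when the installed Airflow core is 3.0.0 or newer; base_version strips
    # any pre-release/dev suffix before the comparison.
    AIRFLOW_V_3_0_PLUS = Version(Version(airflow_version).base_version) >= Version("3.0.0")
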
@@ -31,7 +31,9 @@
from airflow.providers.databricks.plugins.databricks_workflow import (
WorkflowJobRepairAllFailedLink,
WorkflowJobRunLink,
store_databricks_job_run_link,
)
from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils.task_group import TaskGroup

if TYPE_CHECKING:
@@ -92,9 +94,18 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
populated after instantiation using the `add_task` method.
"""

operator_extra_links = (WorkflowJobRunLink(), WorkflowJobRepairAllFailedLink())
template_fields = ("notebook_params", "job_clusters")
caller = "_CreateDatabricksWorkflowOperator"
# Conditionally set operator_extra_links based on Airflow version
if AIRFLOW_V_3_0_PLUS:
# In Airflow 3, disable "Repair All Failed Tasks" since we can't pre-determine failed tasks
operator_extra_links = (WorkflowJobRunLink(),)
else:
# In Airflow 2.x, keep both links
operator_extra_links = ( # type: ignore[assignment]
WorkflowJobRunLink(),
WorkflowJobRepairAllFailedLink(),
)

def __init__(
self,
@@ -219,6 +230,15 @@ def execute(self, context: Context) -> Any:
run_id,
)

# Store operator links in XCom for Airflow 3 compatibility
if AIRFLOW_V_3_0_PLUS:
# Store the job run link
store_databricks_job_run_link(
context=context,
metadata=self.workflow_run_metadata,
logger=self.log,
)

return {
"conn_id": self.databricks_conn_id,
"job_id": job_id,
@@ -40,7 +40,6 @@
else:
from airflow.www import auth # type: ignore
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import TaskGroup

@@ -49,6 +48,7 @@

from airflow.models import BaseOperator
from airflow.providers.databricks.operators.databricks import DatabricksTaskBaseOperator
from airflow.utils.context import Context

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperatorLink
@@ -93,32 +93,56 @@ def get_databricks_task_ids(
return task_ids


@provide_session
def _get_dagrun(dag: DAG, run_id: str, session: Session | None = None) -> DagRun:
"""
Retrieve the DagRun object associated with the specified DAG and run_id.
# TODO: Rethink how to support the repair functionality, which is currently unavailable in Airflow 3. A good time to
# re-evaluate this would be once the plugin functionality is expanded in Airflow 3.1.
if not AIRFLOW_V_3_0_PLUS:
from airflow.utils.session import NEW_SESSION, provide_session

:param dag: The DAG object associated with the DagRun to retrieve.
:param run_id: The run_id associated with the DagRun to retrieve.
:param session: The SQLAlchemy session to use for the query. If None, uses the default session.
:return: The DagRun object associated with the specified DAG and run_id.
"""
if not session:
raise AirflowException("Session not provided.")
@provide_session
def _get_dagrun(dag: DAG, run_id: str, session: Session | None = None) -> DagRun:
"""
Retrieve the DagRun object associated with the specified DAG and run_id.

return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first()
:param dag: The DAG object associated with the DagRun to retrieve.
:param run_id: The run_id associated with the DagRun to retrieve.
:param session: The SQLAlchemy session to use for the query. If None, uses the default session.
:return: The DagRun object associated with the specified DAG and run_id.
"""
if not session:
raise AirflowException("Session not provided.")

return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first()

@provide_session
def _clear_task_instances(
dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session: Session | None = None
) -> None:
dag_bag = DagBag(read_dags_from_db=True)
dag = dag_bag.get_dag(dag_id)
log.debug("task_ids %s to clear", str(task_ids))
dr: DagRun = _get_dagrun(dag, run_id, session=session)
tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
clear_task_instances(tis_to_clear, session)
@provide_session
def _clear_task_instances(
dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session: Session | None = None
) -> None:
dag_bag = DagBag(read_dags_from_db=True)
dag = dag_bag.get_dag(dag_id)
log.debug("task_ids %s to clear", str(task_ids))
dr: DagRun = _get_dagrun(dag, run_id, session=session)
tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
clear_task_instances(tis_to_clear, session)

@provide_session
def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance:
dag_id = operator.dag.dag_id
if hasattr(DagRun, "execution_date"): # Airflow 2.x.
dag_run = DagRun.find(dag_id, execution_date=dttm)[0] # type: ignore[call-arg]
else:
dag_run = DagRun.find(dag_id, logical_date=dttm)[0]
ti = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == dag_run.run_id,
TaskInstance.task_id == operator.task_id,
)
.one_or_none()
)
if not ti:
raise TaskInstanceNotFound("Task instance not found")
return ti


def _repair_task(
@@ -201,27 +225,6 @@ def _get_launch_task_key(current_task_key: TaskInstanceKey, task_id: str) -> Tas
return current_task_key


@provide_session
def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance:
dag_id = operator.dag.dag_id
if hasattr(DagRun, "execution_date"): # Airflow 2.x.
dag_run = DagRun.find(dag_id, execution_date=dttm)[0] # type: ignore[call-arg]
else:
dag_run = DagRun.find(dag_id, logical_date=dttm)[0]
ti = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == dag_run.run_id,
TaskInstance.task_id == operator.task_id,
)
.one_or_none()
)
if not ti:
raise TaskInstanceNotFound("Task instance not found")
return ti


def get_xcom_result(
ti_key: TaskInstanceKey,
key: str,
@@ -240,13 +243,41 @@ class WorkflowJobRunLink(BaseOperatorLink, LoggingMixin):

name = "See Databricks Job Run"

@property
def xcom_key(self) -> str:
"""XCom key where the link is stored during task execution."""
return "databricks_job_run_link"

def get_link(
self,
operator: BaseOperator,
dttm=None,
*,
ti_key: TaskInstanceKey | None = None,
) -> str:
if AIRFLOW_V_3_0_PLUS:
# Use public XCom API to get the pre-computed link
try:
link = XCom.get_value(
ti_key=ti_key,
key=self.xcom_key,
)
return link if link else ""
except Exception as e:
self.log.warning("Failed to retrieve Databricks job run link from XCom: %s", e)
return ""
else:
# Airflow 2.x - keep original implementation
return self._get_link_legacy(operator, dttm, ti_key=ti_key)

def _get_link_legacy(
self,
operator: BaseOperator,
dttm=None,
*,
ti_key: TaskInstanceKey | None = None,
) -> str:
"""Legacy implementation for Airflow 2.x."""
if not ti_key:
ti = get_task_instance(operator, dttm)
ti_key = ti.key
@@ -269,6 +300,30 @@ def get_link(
return f"https://{hook.host}/#job/{metadata.job_id}/run/{metadata.run_id}"


def store_databricks_job_run_link(
context: Context,
metadata: Any,
logger: logging.Logger,
) -> None:
"""
Store the Databricks job run link in XCom during task execution.

This should be called by Databricks operators during their execution.
"""
if not AIRFLOW_V_3_0_PLUS:
return # Only needed for Airflow 3

try:
hook = DatabricksHook(metadata.conn_id)
link = f"https://{hook.host}/#job/{metadata.job_id}/run/{metadata.run_id}"

# Store the link in XCom for the UI to retrieve as an extra link
context["ti"].xcom_push(key="databricks_job_run_link", value=link)
logger.info("Stored Databricks job run link in XCom: %s", link)
except Exception as e:
logger.warning("Failed to store Databricks job run link: %s", e)


class WorkflowJobRepairAllFailedLink(BaseOperatorLink, LoggingMixin):
"""Constructs a link to send a request to repair all failed tasks in the Databricks workflow."""

@@ -455,13 +510,6 @@ def _get_return_url(dag_id: str, run_id: str) -> str:
return url_for("Airflow.grid", dag_id=dag_id, dag_run_id=run_id)


repair_databricks_view = RepairDatabricksTasks()

repair_databricks_package = {
"view": repair_databricks_view,
}


class DatabricksWorkflowPlugin(AirflowPlugin):
"""
Databricks Workflows plugin for Airflow.
@@ -472,9 +520,22 @@ class DatabricksWorkflowPlugin(AirflowPlugin):
"""

name = "databricks_workflow"
operator_extra_links = [
WorkflowJobRepairAllFailedLink(),
WorkflowJobRepairSingleTaskLink(),
WorkflowJobRunLink(),
]
appbuilder_views = [repair_databricks_package]

# Conditionally set operator_extra_links based on Airflow version
if AIRFLOW_V_3_0_PLUS:
# In Airflow 3, disable the repair links until it is clear how they can be supported
operator_extra_links = [
WorkflowJobRunLink(),
]
else:
# In Airflow 2.x, keep all links including repair all failed tasks
operator_extra_links = [
WorkflowJobRepairAllFailedLink(),
WorkflowJobRepairSingleTaskLink(),
WorkflowJobRunLink(),
]
repair_databricks_view = RepairDatabricksTasks()
repair_databricks_package = {
"view": repair_databricks_view,
}
appbuilder_views = [repair_databricks_package]
@@ -133,6 +133,7 @@ def test_wait_for_job_to_start(mock_databricks_hook):
mock_hook_instance.get_run_state.assert_called()


@pytest.mark.db_test
def test_execute(mock_databricks_hook, context, mock_task_group):
"""Test that _CreateDatabricksWorkflowOperator.execute runs the task group."""
operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default")
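
Beyond marking the existing execute test as a db_test, a natural follow-up unit test for the new XCom behaviour could assert that the helper pushes the link under the expected key. A hedged sketch: the helper name, key and URL format come from this diff, while the mock setup, patch target and example values are assumptions, and it presumes an Airflow 3 environment since the helper is a no-op on Airflow 2:

    from unittest import mock

    from airflow.providers.databricks.plugins.databricks_workflow import store_databricks_job_run_link


    def test_store_databricks_job_run_link_pushes_xcom():
        ti = mock.MagicMock()
        context = {"ti": ti}
        metadata = mock.MagicMock(conn_id="databricks_default", job_id=123, run_id=456)

        # Patch the hook where the plugin module looks it up so no workspace call is made.
        with mock.patch(
            "airflow.providers.databricks.plugins.databricks_workflow.DatabricksHook"
        ) as hook_cls:
            hook_cls.return_value.host = "example.cloud.databricks.com"
            store_databricks_job_run_link(context=context, metadata=metadata, logger=mock.MagicMock())

        ti.xcom_push.assert_called_once_with(
            key="databricks_job_run_link",
            value="https://example.cloud.databricks.com/#job/123/run/456",
        )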