From ae340349ab746522bf68d37147064eb933878134 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Wed, 25 Jun 2025 13:40:18 +0530 Subject: [PATCH] Enable DatabricksJobRunLink for Databricks plugin, skip provide_session usage in Airflow3 --- .../databricks/operators/databricks.py | 24 +- .../operators/databricks_workflow.py | 22 +- .../databricks/plugins/databricks_workflow.py | 175 +++++++++---- .../operators/test_databricks_workflow.py | 1 + .../plugins/test_databricks_workflow.py | 245 ++++++++++++++++-- 5 files changed, 377 insertions(+), 90 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 1756f3b50d041..95d5783de0a92 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -42,6 +42,7 @@ from airflow.providers.databricks.plugins.databricks_workflow import ( WorkflowJobRepairSingleTaskLink, WorkflowJobRunLink, + store_databricks_job_run_link, ) from airflow.providers.databricks.triggers.databricks import ( DatabricksExecutionTrigger, @@ -1214,10 +1215,16 @@ def __init__( super().__init__(**kwargs) if self._databricks_workflow_task_group is not None: - self.operator_extra_links = ( - WorkflowJobRunLink(), - WorkflowJobRepairSingleTaskLink(), - ) + # Conditionally set operator_extra_links based on Airflow version. In Airflow 3, only show the job run link. + # In Airflow 2, show the job run link and the repair link. + # TODO: Once we expand the plugin functionality in Airflow 3.1, this can be re-evaluated on how to handle the repair link. + if AIRFLOW_V_3_0_PLUS: + self.operator_extra_links = (WorkflowJobRunLink(),) + else: + self.operator_extra_links = ( + WorkflowJobRunLink(), + WorkflowJobRepairSingleTaskLink(), + ) else: # Databricks does not support repair for non-workflow tasks, hence do not show the repair link. self.operator_extra_links = (DatabricksJobRunLink(),) @@ -1427,6 +1434,15 @@ def execute(self, context: Context) -> None: ) self.databricks_run_id = workflow_run_metadata.run_id self.databricks_conn_id = workflow_run_metadata.conn_id + + # Store operator links in XCom for Airflow 3 compatibility + if AIRFLOW_V_3_0_PLUS: + # Store the job run link + store_databricks_job_run_link( + context=context, + metadata=workflow_run_metadata, + logger=self.log, + ) else: self._launch_job(context=context) if self.wait_for_termination: diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index aa19bac959a94..9f714fa06d95c 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -31,7 +31,9 @@ from airflow.providers.databricks.plugins.databricks_workflow import ( WorkflowJobRepairAllFailedLink, WorkflowJobRunLink, + store_databricks_job_run_link, ) +from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.task_group import TaskGroup if TYPE_CHECKING: @@ -92,9 +94,18 @@ class _CreateDatabricksWorkflowOperator(BaseOperator): populated after instantiation using the `add_task` method. 
""" - operator_extra_links = (WorkflowJobRunLink(), WorkflowJobRepairAllFailedLink()) template_fields = ("notebook_params", "job_clusters") caller = "_CreateDatabricksWorkflowOperator" + # Conditionally set operator_extra_links based on Airflow version + if AIRFLOW_V_3_0_PLUS: + # In Airflow 3, disable "Repair All Failed Tasks" since we can't pre-determine failed tasks + operator_extra_links = (WorkflowJobRunLink(),) + else: + # In Airflow 2.x, keep both links + operator_extra_links = ( # type: ignore[assignment] + WorkflowJobRunLink(), + WorkflowJobRepairAllFailedLink(), + ) def __init__( self, @@ -219,6 +230,15 @@ def execute(self, context: Context) -> Any: run_id, ) + # Store operator links in XCom for Airflow 3 compatibility + if AIRFLOW_V_3_0_PLUS: + # Store the job run link + store_databricks_job_run_link( + context=context, + metadata=self.workflow_run_metadata, + logger=self.log, + ) + return { "conn_id": self.databricks_conn_id, "job_id": job_id, diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py index 16ea7a0b6113e..ab927fb0c1722 100644 --- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py @@ -40,7 +40,6 @@ else: from airflow.www import auth # type: ignore from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup @@ -49,6 +48,7 @@ from airflow.models import BaseOperator from airflow.providers.databricks.operators.databricks import DatabricksTaskBaseOperator + from airflow.utils.context import Context if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperatorLink @@ -93,32 +93,56 @@ def get_databricks_task_ids( return task_ids -@provide_session -def _get_dagrun(dag: DAG, run_id: str, session: Session | None = None) -> DagRun: - """ - Retrieve the DagRun object associated with the specified DAG and run_id. +# TODO: Need to re-think on how to support the currently unavailable repair functionality in Airflow 3. Probably a +# good time to re-evaluate this would be once the plugin functionality is expanded in Airflow 3.1. +if not AIRFLOW_V_3_0_PLUS: + from airflow.utils.session import NEW_SESSION, provide_session - :param dag: The DAG object associated with the DagRun to retrieve. - :param run_id: The run_id associated with the DagRun to retrieve. - :param session: The SQLAlchemy session to use for the query. If None, uses the default session. - :return: The DagRun object associated with the specified DAG and run_id. - """ - if not session: - raise AirflowException("Session not provided.") + @provide_session + def _get_dagrun(dag: DAG, run_id: str, session: Session | None = None) -> DagRun: + """ + Retrieve the DagRun object associated with the specified DAG and run_id. - return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first() + :param dag: The DAG object associated with the DagRun to retrieve. + :param run_id: The run_id associated with the DagRun to retrieve. + :param session: The SQLAlchemy session to use for the query. If None, uses the default session. + :return: The DagRun object associated with the specified DAG and run_id. 
+ """ + if not session: + raise AirflowException("Session not provided.") + return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first() -@provide_session -def _clear_task_instances( - dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session: Session | None = None -) -> None: - dag_bag = DagBag(read_dags_from_db=True) - dag = dag_bag.get_dag(dag_id) - log.debug("task_ids %s to clear", str(task_ids)) - dr: DagRun = _get_dagrun(dag, run_id, session=session) - tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids] - clear_task_instances(tis_to_clear, session) + @provide_session + def _clear_task_instances( + dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session: Session | None = None + ) -> None: + dag_bag = DagBag(read_dags_from_db=True) + dag = dag_bag.get_dag(dag_id) + log.debug("task_ids %s to clear", str(task_ids)) + dr: DagRun = _get_dagrun(dag, run_id, session=session) + tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids] + clear_task_instances(tis_to_clear, session) + + @provide_session + def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance: + dag_id = operator.dag.dag_id + if hasattr(DagRun, "execution_date"): # Airflow 2.x. + dag_run = DagRun.find(dag_id, execution_date=dttm)[0] # type: ignore[call-arg] + else: + dag_run = DagRun.find(dag_id, logical_date=dttm)[0] + ti = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == dag_id, + TaskInstance.run_id == dag_run.run_id, + TaskInstance.task_id == operator.task_id, + ) + .one_or_none() + ) + if not ti: + raise TaskInstanceNotFound("Task instance not found") + return ti def _repair_task( @@ -201,27 +225,6 @@ def _get_launch_task_key(current_task_key: TaskInstanceKey, task_id: str) -> Tas return current_task_key -@provide_session -def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance: - dag_id = operator.dag.dag_id - if hasattr(DagRun, "execution_date"): # Airflow 2.x. 
- dag_run = DagRun.find(dag_id, execution_date=dttm)[0] # type: ignore[call-arg] - else: - dag_run = DagRun.find(dag_id, logical_date=dttm)[0] - ti = ( - session.query(TaskInstance) - .filter( - TaskInstance.dag_id == dag_id, - TaskInstance.run_id == dag_run.run_id, - TaskInstance.task_id == operator.task_id, - ) - .one_or_none() - ) - if not ti: - raise TaskInstanceNotFound("Task instance not found") - return ti - - def get_xcom_result( ti_key: TaskInstanceKey, key: str, @@ -240,6 +243,11 @@ class WorkflowJobRunLink(BaseOperatorLink, LoggingMixin): name = "See Databricks Job Run" + @property + def xcom_key(self) -> str: + """XCom key where the link is stored during task execution.""" + return "databricks_job_run_link" + def get_link( self, operator: BaseOperator, @@ -247,6 +255,29 @@ def get_link( *, ti_key: TaskInstanceKey | None = None, ) -> str: + if AIRFLOW_V_3_0_PLUS: + # Use public XCom API to get the pre-computed link + try: + link = XCom.get_value( + ti_key=ti_key, + key=self.xcom_key, + ) + return link if link else "" + except Exception as e: + self.log.warning("Failed to retrieve Databricks job run link from XCom: %s", e) + return "" + else: + # Airflow 2.x - keep original implementation + return self._get_link_legacy(operator, dttm, ti_key=ti_key) + + def _get_link_legacy( + self, + operator: BaseOperator, + dttm=None, + *, + ti_key: TaskInstanceKey | None = None, + ) -> str: + """Legacy implementation for Airflow 2.x.""" if not ti_key: ti = get_task_instance(operator, dttm) ti_key = ti.key @@ -269,6 +300,30 @@ def get_link( return f"https://{hook.host}/#job/{metadata.job_id}/run/{metadata.run_id}" +def store_databricks_job_run_link( + context: Context, + metadata: Any, + logger: logging.Logger, +) -> None: + """ + Store the Databricks job run link in XCom during task execution. + + This should be called by Databricks operators during their execution. + """ + if not AIRFLOW_V_3_0_PLUS: + return # Only needed for Airflow 3 + + try: + hook = DatabricksHook(metadata.conn_id) + link = f"https://{hook.host}/#job/{metadata.job_id}/run/{metadata.run_id}" + + # Store the link in XCom for the UI to retrieve as extra link + context["ti"].xcom_push(key="databricks_job_run_link", value=link) + logger.info("Stored Databricks job run link in XCom: %s", link) + except Exception as e: + logger.warning("Failed to store Databricks job run link: %s", e) + + class WorkflowJobRepairAllFailedLink(BaseOperatorLink, LoggingMixin): """Constructs a link to send a request to repair all failed tasks in the Databricks workflow.""" @@ -455,13 +510,6 @@ def _get_return_url(dag_id: str, run_id: str) -> str: return url_for("Airflow.grid", dag_id=dag_id, dag_run_id=run_id) -repair_databricks_view = RepairDatabricksTasks() - -repair_databricks_package = { - "view": repair_databricks_view, -} - - class DatabricksWorkflowPlugin(AirflowPlugin): """ Databricks Workflows plugin for Airflow. 
@@ -472,9 +520,22 @@ class DatabricksWorkflowPlugin(AirflowPlugin): """ name = "databricks_workflow" - operator_extra_links = [ - WorkflowJobRepairAllFailedLink(), - WorkflowJobRepairSingleTaskLink(), - WorkflowJobRunLink(), - ] - appbuilder_views = [repair_databricks_package] + + # Conditionally set operator_extra_links based on Airflow version + if AIRFLOW_V_3_0_PLUS: + # In Airflow 3, disable the links for repair functionality until it is figured out it can be supported + operator_extra_links = [ + WorkflowJobRunLink(), + ] + else: + # In Airflow 2.x, keep all links including repair all failed tasks + operator_extra_links = [ + WorkflowJobRepairAllFailedLink(), + WorkflowJobRepairSingleTaskLink(), + WorkflowJobRunLink(), + ] + repair_databricks_view = RepairDatabricksTasks() + repair_databricks_package = { + "view": repair_databricks_view, + } + appbuilder_views = [repair_databricks_package] diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py index 0a4c3ee1f9438..68ae71b8e1c9e 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py @@ -133,6 +133,7 @@ def test_wait_for_job_to_start(mock_databricks_hook): mock_hook_instance.get_run_state.assert_called() +@pytest.mark.db_test def test_execute(mock_databricks_hook, context, mock_task_group): """Test that _CreateDatabricksWorkflowOperator.execute runs the task group.""" operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default") diff --git a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py index f79dbcbcaeb12..e9012742f23f3 100644 --- a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py @@ -17,20 +17,11 @@ from __future__ import annotations -from unittest.mock import MagicMock, Mock, patch +import logging +from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest -from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - pytest.skip( - "``airflow/providers/databricks/plugins/databricks_workflow.py`` is only compatible with Airflow 2.X.", - allow_module_level=True, - ) - -from flask import url_for - from airflow.exceptions import AirflowException from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstanceKey @@ -40,14 +31,13 @@ RepairDatabricksTasks, WorkflowJobRepairSingleTaskLink, WorkflowJobRunLink, - _get_dagrun, _get_launch_task_key, _repair_task, get_databricks_task_ids, get_launch_task_id, - get_task_instance, + store_databricks_job_run_link, ) -from airflow.www.app import create_app +from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES @@ -64,6 +54,8 @@ "task2": MagicMock(dag_id=DAG_ID, task_id="task2", databricks_task_key="task_key2"), } +logger = logging.getLogger(__name__) + def test_get_databricks_task_ids(): result = get_databricks_task_ids(GROUP_ID, TASK_MAP, LOG) @@ -72,7 +64,10 @@ def test_get_databricks_task_ids(): assert result == expected_ids -def test_get_dagrun(): +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only for 
Airflow < 3.0") +def test_get_dagrun_airflow2(): + from airflow.providers.databricks.plugins.databricks_workflow import _get_dagrun + session = MagicMock() dag = MagicMock(dag_id=DAG_ID) session.query.return_value.filter.return_value.first.return_value = DagRun() @@ -82,6 +77,7 @@ def test_get_dagrun(): assert isinstance(result, DagRun) +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow < 3.0") @patch("airflow.providers.databricks.plugins.databricks_workflow.DatabricksHook") def test_repair_task(mock_databricks_hook): mock_hook_instance = mock_databricks_hook.return_value @@ -96,6 +92,7 @@ def test_repair_task(mock_databricks_hook): mock_hook_instance.repair_run.assert_called_once() +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow < 3.0") @patch("airflow.providers.databricks.plugins.databricks_workflow.DatabricksHook") def test_repair_task_with_params(mock_databricks_hook): mock_hook_instance = mock_databricks_hook.return_value @@ -143,16 +140,15 @@ def test_get_launch_task_key(): assert result.run_id == TASK_INSTANCE_KEY.run_id -@pytest.fixture(scope="session") -def app(): +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow < 3.0") +@pytest.mark.db_test +def test_get_task_instance_airflow2(): + from airflow.providers.databricks.plugins.databricks_workflow import get_task_instance + from airflow.www.app import create_app + app = create_app(testing=True) app.config["SERVER_NAME"] = "localhost" - with app.app_context(): - yield app - - -def test_get_task_instance(app): with app.app_context(): operator = Mock() operator.dag.dag_id = "dag_id" @@ -169,18 +165,32 @@ def test_get_task_instance(app): assert result == dag_run -def test_get_return_url_dag_id_run_id(app): +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow < 3.0") +@pytest.mark.db_test +def test_get_return_url_dag_id_run_id_airflow2(): + from flask import url_for + + from airflow.www.app import create_app + dag_id = "example_dag" run_id = "example_run" - expected_url = url_for("Airflow.grid", dag_id=dag_id, dag_run_id=run_id) - + app = create_app(testing=True) + app.config["SERVER_NAME"] = "localhost" with app.app_context(): + expected_url = url_for("Airflow.grid", dag_id=dag_id, dag_run_id=run_id) actual_url = RepairDatabricksTasks._get_return_url(dag_id, run_id) assert actual_url == expected_url, f"Expected {expected_url}, got {actual_url}" -def test_workflow_job_run_link(app): +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow < 3.0") +@pytest.mark.db_test +def test_workflow_job_run_link_airflow2(): + from airflow.www.app import create_app + + app = create_app(testing=True) + app.config["SERVER_NAME"] = "localhost" + with app.app_context(): link = WorkflowJobRunLink() operator = Mock() @@ -214,10 +224,16 @@ def test_workflow_job_run_link(app): assert "https://mockhost/#job/1/run/1" in result +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow < 3.0") @pytest.mark.skipif( RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES, reason="Web plugin test doesn't work when not against sources" ) -def test_workflow_job_repair_single_failed_link(app): +@pytest.mark.db_test +def test_workflow_job_repair_single_failed_link_airflow2(): + from airflow.www.app import create_app + + app = create_app(testing=True) + app.config["SERVER_NAME"] = "localhost" with app.app_context(): link = WorkflowJobRepairSingleTaskLink() operator = Mock() @@ -261,10 +277,183 @@ def test_operator_extra_links(plugin): assert hasattr(link, "get_link") -def 
test_appbuilder_views(plugin): +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow < 3.0") +def test_appbuilder_views_airflow2(plugin): assert plugin.appbuilder_views is not None assert len(plugin.appbuilder_views) == 1 repair_view = plugin.appbuilder_views[0]["view"] assert isinstance(repair_view, RepairDatabricksTasks) assert repair_view.default_view == "repair" + + +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow 3.0+") +class TestDatabricksWorkflowPluginAirflow3: + """Test Databricks Workflow Plugin functionality specific to Airflow 3.x.""" + + def test_plugin_operator_extra_links_limited_functionality(self): + """Test that operator_extra_links are limited in Airflow 3.x (only job run link).""" + plugin = DatabricksWorkflowPlugin() + + # In Airflow 3, only WorkflowJobRunLink should be present + assert len(plugin.operator_extra_links) == 1 + assert isinstance(plugin.operator_extra_links[0], WorkflowJobRunLink) + + # Verify repair links are not present + link_types = [type(link).__name__ for link in plugin.operator_extra_links] + assert not any("Repair" in link_type for link_type in link_types) + + def test_plugin_no_appbuilder_views(self): + """Test that appbuilder_views are not configured in Airflow 3.x.""" + plugin = DatabricksWorkflowPlugin() + + # In Airflow 3, appbuilder_views should not be set (repair functionality disabled) + assert not getattr(plugin, "appbuilder_views", []) + + def test_store_databricks_job_run_link_function_works(self): + """Test that store_databricks_job_run_link works correctly in Airflow 3.x.""" + ti_mock = Mock() + ti_mock.xcom_push = Mock() + + context = { + "ti": ti_mock, + "dag": Mock(dag_id="test_dag"), + "dag_run": Mock(run_id="test_run"), + "task": Mock(task_id="test_task"), + } + + metadata = Mock(conn_id="databricks_default", job_id=12345, run_id=67890) + + with patch("airflow.providers.databricks.plugins.databricks_workflow.DatabricksHook") as mock_hook: + mock_hook_instance = Mock() + mock_hook_instance.host = "test-databricks-host" + mock_hook.return_value = mock_hook_instance + + store_databricks_job_run_link(context, metadata, logger) + + ti_mock.xcom_push.assert_called_once() + + call_args = ti_mock.xcom_push.call_args + assert call_args[1]["key"] == "databricks_job_run_link" + assert "test-databricks-host" in call_args[1]["value"] + assert "12345" in call_args[1]["value"] + assert "67890" in call_args[1]["value"] + assert ti_mock.xcom_push.call_count == 1 + + def test_workflow_job_run_link_uses_xcom(self): + """Test that WorkflowJobRunLink.get_link uses XCom in Airflow 3.x.""" + link = WorkflowJobRunLink() + operator = Mock() + ti_key = TaskInstanceKey(dag_id="test_dag", task_id="test_task", run_id="test_run", try_number=1) + + expected_link = "https://test-host/#job/123/run/456" + + with patch("airflow.providers.databricks.plugins.databricks_workflow.XCom") as mock_xcom: + mock_xcom.get_value.return_value = expected_link + + result = link.get_link(operator, ti_key=ti_key) + + mock_xcom.get_value.assert_called_once_with(ti_key=ti_key, key="databricks_job_run_link") + + assert result == expected_link + + def test_store_databricks_job_run_link_exception_handling(self): + """Test that exceptions are properly handled in store_databricks_job_run_link.""" + ti_mock = Mock() + ti_mock.xcom_push = Mock() + + context = { + "ti": ti_mock, + "dag": Mock(dag_id="test_dag"), + "dag_run": Mock(run_id="test_run"), + "task": Mock(task_id="test_task"), + } + + metadata = Mock(conn_id="databricks_default", 
job_id=12345, run_id=67890) + + with patch("airflow.providers.databricks.plugins.databricks_workflow.DatabricksHook") as mock_hook: + mock_hook_instance = Mock() + type(mock_hook_instance).host = PropertyMock(side_effect=Exception("Connection failed")) + mock_hook.return_value = mock_hook_instance + + store_databricks_job_run_link(context, metadata, logger) + + # Verify no XCom was pushed due to the exception + ti_mock.xcom_push.assert_not_called() + + +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow < 3.0") +class TestDatabricksWorkflowPluginAirflow2: + """Test Databricks Workflow Plugin functionality specific to Airflow 2.x.""" + + def test_plugin_operator_extra_links_full_functionality(self): + """Test that all operator_extra_links are present in Airflow 2.x.""" + plugin = DatabricksWorkflowPlugin() + + # In Airflow 2.x, all links should be present including repair links + assert len(plugin.operator_extra_links) >= 2 # At least job run link + repair links + link_types = [type(link).__name__ for link in plugin.operator_extra_links] + assert "WorkflowJobRunLink" in link_types + # Should have repair links in 2.x + assert any("Repair" in link_type for link_type in link_types) + + def test_plugin_has_appbuilder_views(self): + """Test that appbuilder_views are configured for repair functionality in Airflow 2.x.""" + plugin = DatabricksWorkflowPlugin() + + # In Airflow 2.x, appbuilder_views should be present for repair functionality + assert hasattr(plugin, "appbuilder_views") + assert plugin.appbuilder_views is not None + + def test_store_databricks_job_run_link_returns_early(self): + """Test that store_databricks_job_run_link returns early in Airflow 2.x.""" + ti_mock = Mock() + ti_mock.xcom_push = Mock() + + context = { + "ti": ti_mock, + "dag": Mock(dag_id="test_dag"), + "dag_run": Mock(run_id="test_run"), + "task": Mock(task_id="test_task"), + } + + metadata = Mock(conn_id="databricks_default", job_id=12345, run_id=67890) + + store_databricks_job_run_link(context, metadata, logger) + + ti_mock.xcom_push.assert_not_called() + + def test_workflow_job_run_link_uses_legacy_method(self): + """Test that WorkflowJobRunLink.get_link uses legacy method in Airflow 2.x.""" + link = WorkflowJobRunLink() + operator = Mock() + operator.task_group = Mock() + operator.task_group.group_id = "test_group" + + ti_key = TaskInstanceKey(dag_id="test_dag", task_id="test_task", run_id="test_run", try_number=1) + + with patch( + "airflow.providers.databricks.plugins.databricks_workflow.get_task_instance" + ) as mock_get_ti: + with patch( + "airflow.providers.databricks.plugins.databricks_workflow.get_xcom_result" + ) as mock_get_xcom: + with patch("airflow.providers.databricks.plugins.databricks_workflow.DagBag") as mock_dag_bag: + with patch( + "airflow.providers.databricks.plugins.databricks_workflow.DatabricksHook" + ) as mock_hook: + mock_get_ti.return_value = Mock(key=ti_key) + mock_get_xcom.return_value = Mock(conn_id="conn_id", run_id=1, job_id=1) + mock_dag_bag.return_value.get_dag.return_value.get_task.return_value = Mock( + task_id="test_task" + ) + + mock_hook_instance = Mock() + mock_hook_instance.host = "test-host" + mock_hook.return_value = mock_hook_instance + + result = link.get_link(operator, ti_key=ti_key) + + # Verify legacy method was used (should contain databricks host) + assert "test-host" in result + assert "#job/1/run/1" in result
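
---

For reference, the pattern this patch introduces can be summarized outside the provider code. The sketch below is illustrative only and is not part of the patch: it assumes an Airflow 3 environment, and the names ExampleJobRunLink, ExampleOperator, the "example_job_run_link" XCom key, and the example host URL are invented for demonstration. It mirrors the approach in the diff: the operator pushes a pre-computed URL to XCom during execute(), and the extra link later reads it back with XCom.get_value() instead of querying the metadata database via provide_session.

    from __future__ import annotations

    import logging

    from airflow.models import BaseOperator, XCom
    from airflow.sdk import BaseOperatorLink  # Airflow 3 import path, as used in the patched plugin

    log = logging.getLogger(__name__)

    # Hypothetical key, analogous to the patch's "databricks_job_run_link".
    XCOM_KEY = "example_job_run_link"


    class ExampleJobRunLink(BaseOperatorLink):
        """Extra link that only reads a URL previously pushed to XCom, so no DB session is needed."""

        name = "See Example Job Run"

        def get_link(self, operator, dttm=None, *, ti_key=None) -> str:
            try:
                # Same public API the patch relies on in WorkflowJobRunLink.get_link().
                link = XCom.get_value(ti_key=ti_key, key=XCOM_KEY)
                return link if link else ""
            except Exception as e:
                log.warning("Failed to retrieve job run link from XCom: %s", e)
                return ""


    class ExampleOperator(BaseOperator):
        """Hypothetical operator that pre-computes its run URL at execution time."""

        operator_extra_links = (ExampleJobRunLink(),)

        def execute(self, context):
            # Build the URL while connection metadata is available, then push it to XCom
            # so the UI can render the extra link without provide_session.
            link = "https://example-databricks-host/#job/123/run/456"
            context["ti"].xcom_push(key=XCOM_KEY, value=link)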