diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 1a0ab8e8c6ba8..710074d239af0 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -491,6 +491,17 @@ def get_run_output(self, run_id: int) -> dict:
         run_output = self._do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json)
         return run_output
 
+    async def a_get_run_output(self, run_id: int) -> dict:
+        """
+        Async version of `get_run_output()`.
+
+        :param run_id: id of the run
+        :return: output of the run
+        """
+        json = {"run_id": run_id}
+        run_output = await self._a_do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json)
+        return run_output
+
     def cancel_run(self, run_id: int) -> None:
         """
         Cancel the run.
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 1f16e5667b9a5..c38b0683c37b3 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -70,10 +70,9 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
                 if run_state.result_state == "FAILED":
                     task_run_id = None
-                    if "tasks" in run_info:
-                        for task in run_info["tasks"]:
-                            if task.get("state", {}).get("result_state", "") == "FAILED":
-                                task_run_id = task["run_id"]
+                    for task in run_info.get("tasks", []):
+                        if task.get("state", {}).get("result_state", "") == "FAILED":
+                            task_run_id = task["run_id"]
                     if task_run_id is not None:
                         run_output = hook.get_run_output(task_run_id)
                         if "error" in run_output:
@@ -160,13 +159,15 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
     validate_trigger_event(event)
    run_state = RunState.from_json(event["run_state"])
     run_page_url = event["run_page_url"]
+    errors = event["errors"]
     log.info("View run status, Spark UI, and logs at %s", run_page_url)
 
     if run_state.is_successful:
         log.info("Job run completed successfully.")
         return
 
-    error_message = f"Job run failed with terminal state: {run_state}"
+    error_message = f"Job run failed with terminal state: {run_state} and with the errors {errors}"
+
     if event["repair_run"]:
         log.warning(
             "%s but since repair run is set, repairing the run with all failed tasks",
diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py
index 4c1eecb85f7fd..d20202fdca7f8 100644
--- a/airflow/providers/databricks/triggers/databricks.py
+++ b/airflow/providers/databricks/triggers/databricks.py
@@ -84,21 +84,36 @@ async def run(self):
         async with self.hook:
             while True:
                 run_state = await self.hook.a_get_run_state(self.run_id)
-                if run_state.is_terminal:
-                    yield TriggerEvent(
-                        {
-                            "run_id": self.run_id,
-                            "run_page_url": self.run_page_url,
-                            "run_state": run_state.to_json(),
-                            "repair_run": self.repair_run,
-                        }
+                if not run_state.is_terminal:
+                    self.log.info(
+                        "run-id %s in run state %s. sleeping for %s seconds",
+                        self.run_id,
+                        run_state,
+                        self.polling_period_seconds,
                     )
-                    return
+                    await asyncio.sleep(self.polling_period_seconds)
+                    continue
 
-                self.log.info(
-                    "run-id %s in run state %s. sleeping for %s seconds",
-                    self.run_id,
-                    run_state,
-                    self.polling_period_seconds,
+                failed_tasks = []
+                if run_state.result_state == "FAILED":
+                    run_info = await self.hook.a_get_run(self.run_id)
+                    for task in run_info.get("tasks", []):
+                        if task.get("state", {}).get("result_state", "") == "FAILED":
+                            task_run_id = task["run_id"]
+                            task_key = task["task_key"]
+                            run_output = await self.hook.a_get_run_output(task_run_id)
+                            if "error" in run_output:
+                                error = run_output["error"]
+                            else:
+                                error = run_state.state_message
+                            failed_tasks.append({"task_key": task_key, "run_id": task_run_id, "error": error})
+                yield TriggerEvent(
+                    {
+                        "run_id": self.run_id,
+                        "run_page_url": self.run_page_url,
+                        "run_state": run_state.to_json(),
+                        "repair_run": self.repair_run,
+                        "errors": failed_tasks,
+                    }
                 )
-                await asyncio.sleep(self.polling_period_seconds)
+                return
diff --git a/airflow/providers/databricks/utils/databricks.py b/airflow/providers/databricks/utils/databricks.py
index 0635017b28f80..88d622c3bc1fb 100644
--- a/airflow/providers/databricks/utils/databricks.py
+++ b/airflow/providers/databricks/utils/databricks.py
@@ -55,7 +55,7 @@ def validate_trigger_event(event: dict):
     See: :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger`.
     """
-    keys_to_check = ["run_id", "run_page_url", "run_state"]
+    keys_to_check = ["run_id", "run_page_url", "run_state", "errors"]
     for key in keys_to_check:
         if key not in event:
             raise AirflowException(f"Could not find `{key}` in the event: {event}")
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 64d4de1d37766..0f1d2c242e456 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -1603,6 +1603,23 @@ async def test_get_cluster_state(self, mock_get):
             timeout=self.hook.timeout_seconds,
         )
 
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
+    async def test_get_run_output(self, mock_get):
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_OUTPUT_RESPONSE)
+        async with self.hook:
+            run_output = await self.hook.a_get_run_output(RUN_ID)
+            run_output_error = run_output.get("error")
+
+        assert run_output_error == ERROR_MESSAGE
+        mock_get.assert_called_once_with(
+            get_run_output_endpoint(HOST),
+            json={"run_id": RUN_ID},
+            auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
+            headers=self.hook.user_agent_header,
+            timeout=self.hook.timeout_seconds,
+        )
+
 
 @pytest.mark.db_test
 class TestDatabricksHookAsyncAadToken:
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 902aa37e918ea..e6cb240dfc9f3 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -1024,6 +1024,7 @@ def test_execute_complete_success(self):
             "run_id": RUN_ID,
             "run_page_url": RUN_PAGE_URL,
             "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
+            "errors": [],
         }
 
         op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
@@ -1044,6 +1045,7 @@ def test_execute_complete_failure(self, db_mock_class):
             "run_page_url": RUN_PAGE_URL,
             "run_state": run_state_failed.to_json(),
             "repair_run": False,
+            "errors": [],
         }
 
         op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
@@ -1594,6 +1596,7 @@ def test_execute_complete_success(self):
             "run_page_url": RUN_PAGE_URL,
             "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
RunState("TERMINATED", "SUCCESS", "").to_json(), "repair_run": False, + "errors": [], } op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) @@ -1611,6 +1614,7 @@ def test_execute_complete_failure(self, db_mock_class): "run_page_url": RUN_PAGE_URL, "run_state": run_state_failed.to_json(), "repair_run": False, + "errors": [], } op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) @@ -1641,6 +1645,7 @@ def test_execute_complete_failure_and_repair_run( "run_page_url": RUN_PAGE_URL, "run_state": run_state_failed.to_json(), "repair_run": True, + "errors": [], } op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) diff --git a/tests/providers/databricks/triggers/test_databricks.py b/tests/providers/databricks/triggers/test_databricks.py index a0313a31f4fa0..b4bcbc133c0ce 100644 --- a/tests/providers/databricks/triggers/test_databricks.py +++ b/tests/providers/databricks/triggers/test_databricks.py @@ -38,13 +38,22 @@ RETRY_DELAY = 10 RETRY_LIMIT = 3 RUN_ID = 1 +TASK_RUN_ID1 = 11 +TASK_RUN_ID1_KEY = "first_task" +TASK_RUN_ID2 = 22 +TASK_RUN_ID2_KEY = "second_task" +TASK_RUN_ID3 = 33 +TASK_RUN_ID3_KEY = "third_task" JOB_ID = 42 RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1" +ERROR_MESSAGE = "error message from databricks API" +GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}} RUN_LIFE_CYCLE_STATES = ["PENDING", "RUNNING", "TERMINATING", "TERMINATED", "SKIPPED", "INTERNAL_ERROR"] LIFE_CYCLE_STATE_PENDING = "PENDING" LIFE_CYCLE_STATE_TERMINATED = "TERMINATED" +LIFE_CYCLE_STATE_INTERNAL_ERROR = "INTERNAL_ERROR" STATE_MESSAGE = "Waiting for cluster" @@ -66,6 +75,44 @@ "result_state": "SUCCESS", }, } +GET_RUN_RESPONSE_TERMINATED_WITH_FAILED = { + "job_id": JOB_ID, + "run_page_url": RUN_PAGE_URL, + "state": { + "life_cycle_state": LIFE_CYCLE_STATE_INTERNAL_ERROR, + "state_message": None, + "result_state": "FAILED", + }, + "tasks": [ + { + "run_id": TASK_RUN_ID1, + "task_key": TASK_RUN_ID1_KEY, + "state": { + "life_cycle_state": "TERMINATED", + "result_state": "FAILED", + "state_message": "Workload failed, see run output for details", + }, + }, + { + "run_id": TASK_RUN_ID2, + "task_key": TASK_RUN_ID2_KEY, + "state": { + "life_cycle_state": "TERMINATED", + "result_state": "SUCCESS", + "state_message": None, + }, + }, + { + "run_id": TASK_RUN_ID3, + "task_key": TASK_RUN_ID3_KEY, + "state": { + "life_cycle_state": "TERMINATED", + "result_state": "FAILED", + "state_message": "Workload failed, see run output for details", + }, + }, + ], +} class TestDatabricksExecutionTrigger: @@ -101,15 +148,21 @@ def test_serialize(self): ) @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") - async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_url): + async def test_run_return_success( + self, mock_get_run_state, mock_get_run_page_url, mock_get_run, mock_get_run_output + ): mock_get_run_page_url.return_value = RUN_PAGE_URL mock_get_run_state.return_value = RunState( life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS", ) + mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED + 
+        mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE
 
         trigger_event = self.trigger.run()
         async for event in trigger_event:
@@ -121,13 +174,52 @@ async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_ur
                 ).to_json(),
                 "run_page_url": RUN_PAGE_URL,
                 "repair_run": False,
+                "errors": [],
             }
         )
 
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output")
+    @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run")
+    @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url")
+    @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
+    async def test_run_return_failure(
+        self, mock_get_run_state, mock_get_run_page_url, mock_get_run, mock_get_run_output
+    ):
+        mock_get_run_page_url.return_value = RUN_PAGE_URL
+        mock_get_run_state.return_value = RunState(
+            life_cycle_state=LIFE_CYCLE_STATE_TERMINATED,
+            state_message="",
+            result_state="FAILED",
+        )
+        mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE
+        mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED
+
+        trigger_event = self.trigger.run()
+        async for event in trigger_event:
+            assert event == TriggerEvent(
+                {
+                    "run_id": RUN_ID,
+                    "run_state": RunState(
+                        life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="FAILED"
+                    ).to_json(),
+                    "run_page_url": RUN_PAGE_URL,
+                    "repair_run": False,
+                    "errors": [
+                        {"task_key": TASK_RUN_ID1_KEY, "run_id": TASK_RUN_ID1, "error": ERROR_MESSAGE},
+                        {"task_key": TASK_RUN_ID3_KEY, "run_id": TASK_RUN_ID3, "error": ERROR_MESSAGE},
+                    ],
+                }
+            )
+
     @pytest.mark.asyncio
+    @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output")
+    @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run")
     @mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep")
     @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
-    async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
+    async def test_sleep_between_retries(
+        self, mock_get_run_state, mock_sleep, mock_get_run, mock_get_run_output
+    ):
         mock_get_run_state.side_effect = [
             RunState(
                 life_cycle_state=LIFE_CYCLE_STATE_PENDING,
@@ -140,6 +232,8 @@ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
                 result_state="SUCCESS",
             ),
         ]
+        mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED
+        mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE
 
         trigger_event = self.trigger.run()
         async for event in trigger_event:
@@ -151,6 +245,7 @@ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
                 ).to_json(),
                 "run_page_url": RUN_PAGE_URL,
                 "repair_run": False,
+                "errors": [],
             }
         )
         mock_sleep.assert_called_once()
diff --git a/tests/providers/databricks/utils/test_databricks.py b/tests/providers/databricks/utils/test_databricks.py
index 7619bcb8ad07f..8c6ce8ce4ba59 100644
--- a/tests/providers/databricks/utils/test_databricks.py
+++ b/tests/providers/databricks/utils/test_databricks.py
@@ -53,6 +53,7 @@ def test_validate_trigger_event_success(self):
             "run_id": RUN_ID,
             "run_page_url": RUN_PAGE_URL,
             "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
+            "errors": [],
         }
         assert validate_trigger_event(event) is None
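
Note: the snippet below is a standalone sketch of the behavior this patch adds, not part of the diff. It mirrors the loop inserted into DatabricksExecutionTrigger.run(): for every task whose result_state is FAILED, the trigger fetches that task's run output and attaches an "errors" list of {"task_key", "run_id", "error"} entries to the TriggerEvent it yields. The helper name collect_failed_tasks and the synchronous get_output callable are illustrative stand-ins (the real call is the async DatabricksHook.a_get_run_output); the fixture data mirrors GET_RUN_RESPONSE_TERMINATED_WITH_FAILED above.

# Hypothetical names for illustration only; not part of the provider's API.
def collect_failed_tasks(run_info: dict, get_output) -> list[dict]:
    # Mirrors the failure handling added to DatabricksExecutionTrigger.run():
    # keep each FAILED task's run-output error, falling back to the run-level
    # state message when the output carries no "error" key.
    failed_tasks = []
    for task in run_info.get("tasks", []):
        if task.get("state", {}).get("result_state", "") == "FAILED":
            run_output = get_output(task["run_id"])
            error = run_output.get("error", run_info.get("state", {}).get("state_message"))
            failed_tasks.append({"task_key": task["task_key"], "run_id": task["run_id"], "error": error})
    return failed_tasks

run_info = {
    "state": {"result_state": "FAILED", "state_message": None},
    "tasks": [
        {"run_id": 11, "task_key": "first_task", "state": {"result_state": "FAILED"}},
        {"run_id": 22, "task_key": "second_task", "state": {"result_state": "SUCCESS"}},
        {"run_id": 33, "task_key": "third_task", "state": {"result_state": "FAILED"}},
    ],
}

# Stub standing in for the runs/get-output API call.
errors = collect_failed_tasks(run_info, lambda run_id: {"error": "error message from databricks API"})
assert errors == [
    {"task_key": "first_task", "run_id": 11, "error": "error message from databricks API"},
    {"task_key": "third_task", "run_id": 33, "error": "error message from databricks API"},
]

This "errors" payload is what validate_trigger_event() now requires and what _handle_deferrable_databricks_operator_completion() interpolates into its failure message, so a deferrable operator's exception names the failing tasks instead of only the terminal run state.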